mirror of https://github.com/langgenius/dify.git
Compare commits
290 Commits
main
...
0.10.0-bet
| Author | SHA1 | Date |
|---|---|---|
|
|
afb3e317d3 | |
|
|
1e62ad23e6 | |
|
|
f9baabc9f9 | |
|
|
94b946c715 | |
|
|
3ff88f4f3c | |
|
|
e7f425be91 | |
|
|
584053bc01 | |
|
|
44e81dbbc8 | |
|
|
944cfd2b68 | |
|
|
6d2682c751 | |
|
|
d2971e84bb | |
|
|
b67b81bf8f | |
|
|
c05902404d | |
|
|
1e6d5f2c48 | |
|
|
e2b1464db2 | |
|
|
f0285a53d2 | |
|
|
00f91b5dc4 | |
|
|
dc3e86b82a | |
|
|
d239c5b54d | |
|
|
23abccd3a6 | |
|
|
2520e40059 | |
|
|
1d4ed3d9e7 | |
|
|
eed8ab9348 | |
|
|
112aaf6e1b | |
|
|
094a1a1458 | |
|
|
955fa4345a | |
|
|
ac5e381a1a | |
|
|
ae9b9f867a | |
|
|
8fd04e5313 | |
|
|
3904782647 | |
|
|
288be3fbd8 | |
|
|
f7f836d6f1 | |
|
|
5dedcb74a5 | |
|
|
b95d0fa9a9 | |
|
|
543503c398 | |
|
|
3f16caf244 | |
|
|
54133dfbde | |
|
|
b491c93b1c | |
|
|
2a6d9c3211 | |
|
|
c6691bd297 | |
|
|
2a0b30de5c | |
|
|
a7d53abba9 | |
|
|
296253a365 | |
|
|
c89cefe526 | |
|
|
1d027fa065 | |
|
|
9ce9a52a86 | |
|
|
c74424ed85 | |
|
|
719ef9cef9 | |
|
|
0ab525a691 | |
|
|
6fdcf6ee21 | |
|
|
d01e97c1fc | |
|
|
87e560de8a | |
|
|
f8d26e46ac | |
|
|
195ac19774 | |
|
|
0281eb796d | |
|
|
9fe2f321ae | |
|
|
5f76e665a1 | |
|
|
81568752c0 | |
|
|
ceb1dde714 | |
|
|
3209fdca53 | |
|
|
dc5010d833 | |
|
|
8b26ae6532 | |
|
|
66953d57a2 | |
|
|
afc9630cd0 | |
|
|
7e8bafe186 | |
|
|
6c5fcd1ffc | |
|
|
7602d22133 | |
|
|
5ec91e8507 | |
|
|
466966f027 | |
|
|
212d04ea27 | |
|
|
0cb50dd4a5 | |
|
|
ab19fccf3d | |
|
|
4ed46e3fed | |
|
|
9fd2f798ff | |
|
|
146be41b1d | |
|
|
ce6ae5732a | |
|
|
edf462c640 | |
|
|
d580fc1e9d | |
|
|
5544791031 | |
|
|
099746dd59 | |
|
|
c6f53c9030 | |
|
|
8236f8fed8 | |
|
|
2b0c39ed3f | |
|
|
396a240e68 | |
|
|
8bd9d8f6ba | |
|
|
aa7ae4c5f1 | |
|
|
49b7acf52e | |
|
|
466ac987f5 | |
|
|
49972939a9 | |
|
|
80f167ca02 | |
|
|
f652ae0d98 | |
|
|
4dbf56675a | |
|
|
f5d1f5a20a | |
|
|
fd9b71c4d7 | |
|
|
1df41cef4c | |
|
|
602d2486bd | |
|
|
403fede432 | |
|
|
9f66e6e357 | |
|
|
affb2e38a1 | |
|
|
31d87f85b8 | |
|
|
54105e85ff | |
|
|
5ec604500c | |
|
|
96d2582d89 | |
|
|
a10b0db102 | |
|
|
5dd556b4c8 | |
|
|
a4c6d0b94b | |
|
|
323a835de9 | |
|
|
0076577764 | |
|
|
9a3b7345c4 | |
|
|
2ebf5f5ffa | |
|
|
02f494c0de | |
|
|
f0e81e3918 | |
|
|
aa8499efac | |
|
|
ea40b1dcb2 | |
|
|
a689cd6fd4 | |
|
|
32b6c7063a | |
|
|
97056dad30 | |
|
|
264f7c2139 | |
|
|
007a6fd14a | |
|
|
c159b7a781 | |
|
|
6c9c3faf78 | |
|
|
d933ebb845 | |
|
|
b60c7a5826 | |
|
|
0b94218378 | |
|
|
97cc9a5615 | |
|
|
f6d0fd9848 | |
|
|
b863dd7de2 | |
|
|
b0e7a22a27 | |
|
|
565a835947 | |
|
|
fe94c876fb | |
|
|
67a34bdd7a | |
|
|
8c785e268b | |
|
|
65a6265ff6 | |
|
|
08d3cb1912 | |
|
|
48d8b01d81 | |
|
|
38edb06897 | |
|
|
dc919c2a6c | |
|
|
e7a6a0ab01 | |
|
|
61d989f413 | |
|
|
976efd93a1 | |
|
|
0e2f78b3a6 | |
|
|
b3529d3ccc | |
|
|
d69b453729 | |
|
|
2f658de155 | |
|
|
a691700b48 | |
|
|
c5317d8f58 | |
|
|
822f03f3cd | |
|
|
101e56baaa | |
|
|
3a8f516dfc | |
|
|
912030c9a1 | |
|
|
687661eef7 | |
|
|
8efc63a705 | |
|
|
dca4f9fe9c | |
|
|
51597629b1 | |
|
|
76a07513ba | |
|
|
dae62bef78 | |
|
|
2a6629d435 | |
|
|
41f0ce1012 | |
|
|
e90b055c47 | |
|
|
94e40d4ed9 | |
|
|
c34fc071e0 | |
|
|
c014ae43e1 | |
|
|
9851153d38 | |
|
|
cfbabb8383 | |
|
|
b78e90679d | |
|
|
ec1bfdc723 | |
|
|
e20019f6e9 | |
|
|
2122cfb152 | |
|
|
c2b8beffac | |
|
|
985651454a | |
|
|
f9c1d06e91 | |
|
|
657f1d2de8 | |
|
|
6e2192c1e0 | |
|
|
e05b20eb91 | |
|
|
5117e08def | |
|
|
34691ca6c9 | |
|
|
aa40047b08 | |
|
|
eca17767fe | |
|
|
51cec1b9ba | |
|
|
651547c3ef | |
|
|
8fbdaa604c | |
|
|
1bcb30647f | |
|
|
bc245a25bf | |
|
|
85b25ebe1b | |
|
|
b50e94d681 | |
|
|
91c0657cf6 | |
|
|
0da06128e3 | |
|
|
0c4af3a1d2 | |
|
|
5628b293f8 | |
|
|
fff40aae58 | |
|
|
b3b87b3e4c | |
|
|
9a23cd08d8 | |
|
|
cf61ca24e3 | |
|
|
58a56add9c | |
|
|
b362031baf | |
|
|
7ad409b3d9 | |
|
|
876ea90fe9 | |
|
|
0eb442f954 | |
|
|
4554ac3ef8 | |
|
|
eaa7d114dc | |
|
|
581228be74 | |
|
|
02da0219ff | |
|
|
d0bbe43dab | |
|
|
16acdc9be4 | |
|
|
a6999b5d02 | |
|
|
33bfa4758e | |
|
|
db63c2c219 | |
|
|
bea4ec5998 | |
|
|
74333db4c8 | |
|
|
0019fb9f8b | |
|
|
47615ac8fb | |
|
|
d7c8bced9b | |
|
|
57f178902f | |
|
|
4586de48d6 | |
|
|
6549519fa5 | |
|
|
ae098ad121 | |
|
|
20922fde1c | |
|
|
079c802b5c | |
|
|
efcd462a69 | |
|
|
843c8ad306 | |
|
|
594bf96922 | |
|
|
ade385c9c1 | |
|
|
baed068231 | |
|
|
42f5334ae4 | |
|
|
3c4ab0632d | |
|
|
bc5f109308 | |
|
|
97b2a42cc3 | |
|
|
939df16655 | |
|
|
9362ae045c | |
|
|
257c515178 | |
|
|
6b7520ccc2 | |
|
|
85eeaee95a | |
|
|
99bf3ff565 | |
|
|
36ae154ca2 | |
|
|
ef93d60534 | |
|
|
6c9a6b99e0 | |
|
|
b73f05fdf0 | |
|
|
26bca75884 | |
|
|
e2962da1b8 | |
|
|
1b9ebb8037 | |
|
|
a945a45b06 | |
|
|
be829a8103 | |
|
|
9432d41e60 | |
|
|
0beeb4ab3e | |
|
|
d7e057be44 | |
|
|
81b11c08d0 | |
|
|
83a5cdfff9 | |
|
|
c837218bc9 | |
|
|
68552893ef | |
|
|
5ba93ed064 | |
|
|
959107f553 | |
|
|
443d929137 | |
|
|
1e04418023 | |
|
|
aeda8869bc | |
|
|
10eed02ec4 | |
|
|
2472c4f890 | |
|
|
0455e4e1a5 | |
|
|
251ab5418f | |
|
|
38e6e40900 | |
|
|
b3a3672857 | |
|
|
53a3c199ec | |
|
|
fca5af5073 | |
|
|
77d0aac1d3 | |
|
|
fd0f8f33b5 | |
|
|
0be99ad01c | |
|
|
a05d16375e | |
|
|
0480bb03c3 | |
|
|
19dfc6d9a8 | |
|
|
d361675159 | |
|
|
23ae150298 | |
|
|
81383d7c74 | |
|
|
573f653789 | |
|
|
f1b61861b6 | |
|
|
8ecee8abce | |
|
|
e9ce9c1f47 | |
|
|
944fea4cc9 | |
|
|
25c029877a | |
|
|
9c31c56115 | |
|
|
56507c9f7a | |
|
|
b322dda3f6 | |
|
|
52d69dd55b | |
|
|
0451c5590c | |
|
|
2498c238b2 | |
|
|
6e15d7f777 | |
|
|
f6caf0915b | |
|
|
09aa14ca82 | |
|
|
394f06a27a | |
|
|
6fafd410d2 | |
|
|
1668df104f | |
|
|
d376b8540e |
|
|
@ -201,6 +201,8 @@ OPENSEARCH_SECURE=true
|
||||||
UPLOAD_FILE_SIZE_LIMIT=15
|
UPLOAD_FILE_SIZE_LIMIT=15
|
||||||
UPLOAD_FILE_BATCH_LIMIT=5
|
UPLOAD_FILE_BATCH_LIMIT=5
|
||||||
UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
|
UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
|
||||||
|
UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
|
||||||
|
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
|
||||||
|
|
||||||
# Model Configuration
|
# Model Configuration
|
||||||
MULTIMODAL_SEND_IMAGE_FORMAT=base64
|
MULTIMODAL_SEND_IMAGE_FORMAT=base64
|
||||||
|
|
@ -275,6 +277,7 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=1000
|
||||||
WORKFLOW_MAX_EXECUTION_STEPS=500
|
WORKFLOW_MAX_EXECUTION_STEPS=500
|
||||||
WORKFLOW_MAX_EXECUTION_TIME=1200
|
WORKFLOW_MAX_EXECUTION_TIME=1200
|
||||||
WORKFLOW_CALL_MAX_DEPTH=5
|
WORKFLOW_CALL_MAX_DEPTH=5
|
||||||
|
MAX_VARIABLE_SIZE=204800
|
||||||
|
|
||||||
# App configuration
|
# App configuration
|
||||||
APP_MAX_EXECUTION_TIME=1200
|
APP_MAX_EXECUTION_TIME=1200
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,15 @@
|
||||||
{
|
{
|
||||||
"version": "0.2.0",
|
"version": "0.2.0",
|
||||||
|
"compounds": [
|
||||||
|
{
|
||||||
|
"name": "Launch Flask and Celery",
|
||||||
|
"configurations": ["Python: Flask", "Python: Celery"]
|
||||||
|
}
|
||||||
|
],
|
||||||
"configurations": [
|
"configurations": [
|
||||||
{
|
{
|
||||||
"name": "Python: Flask",
|
"name": "Python: Flask",
|
||||||
|
"consoleName": "Flask",
|
||||||
"type": "debugpy",
|
"type": "debugpy",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"python": "${workspaceFolder}/.venv/bin/python",
|
"python": "${workspaceFolder}/.venv/bin/python",
|
||||||
|
|
@ -17,12 +24,12 @@
|
||||||
},
|
},
|
||||||
"args": [
|
"args": [
|
||||||
"run",
|
"run",
|
||||||
"--host=0.0.0.0",
|
|
||||||
"--port=5001"
|
"--port=5001"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Python: Celery",
|
"name": "Python: Celery",
|
||||||
|
"consoleName": "Celery",
|
||||||
"type": "debugpy",
|
"type": "debugpy",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"python": "${workspaceFolder}/.venv/bin/python",
|
"python": "${workspaceFolder}/.venv/bin/python",
|
||||||
|
|
@ -45,10 +52,10 @@
|
||||||
"-c",
|
"-c",
|
||||||
"1",
|
"1",
|
||||||
"--loglevel",
|
"--loglevel",
|
||||||
"info",
|
"DEBUG",
|
||||||
"-Q",
|
"-Q",
|
||||||
"dataset,generation,mail,ops_trace,app_deletion"
|
"dataset,generation,mail,ops_trace,app_deletion"
|
||||||
]
|
]
|
||||||
},
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -117,7 +117,7 @@ def create_app() -> Flask:
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=app.config.get("LOG_LEVEL"),
|
level=app.config.get("LOG_LEVEL"),
|
||||||
format=app.config.get("LOG_FORMAT"),
|
format=app.config["LOG_FORMAT"],
|
||||||
datefmt=app.config.get("LOG_DATEFORMAT"),
|
datefmt=app.config.get("LOG_DATEFORMAT"),
|
||||||
handlers=log_handlers,
|
handlers=log_handlers,
|
||||||
force=True,
|
force=True,
|
||||||
|
|
@ -134,6 +134,7 @@ def create_app() -> Flask:
|
||||||
return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
|
return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
|
||||||
|
|
||||||
for handler in logging.root.handlers:
|
for handler in logging.root.handlers:
|
||||||
|
assert handler.formatter
|
||||||
handler.formatter.converter = time_converter
|
handler.formatter.converter = time_converter
|
||||||
initialize_extensions(app)
|
initialize_extensions(app)
|
||||||
register_blueprints(app)
|
register_blueprints(app)
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ from extensions.ext_redis import redis_client
|
||||||
from libs.helper import email as email_validate
|
from libs.helper import email as email_validate
|
||||||
from libs.password import hash_password, password_pattern, valid_password
|
from libs.password import hash_password, password_pattern, valid_password
|
||||||
from libs.rsa import generate_key_pair
|
from libs.rsa import generate_key_pair
|
||||||
from models.account import Tenant
|
from models import Tenant
|
||||||
from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
|
from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
|
||||||
from models.dataset import Document as DatasetDocument
|
from models.dataset import Document as DatasetDocument
|
||||||
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
|
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
|
||||||
|
|
@ -449,14 +449,14 @@ def convert_to_agent_apps():
|
||||||
# fetch first 1000 apps
|
# fetch first 1000 apps
|
||||||
sql_query = """SELECT a.id AS id FROM apps a
|
sql_query = """SELECT a.id AS id FROM apps a
|
||||||
INNER JOIN app_model_configs am ON a.app_model_config_id=am.id
|
INNER JOIN app_model_configs am ON a.app_model_config_id=am.id
|
||||||
WHERE a.mode = 'chat'
|
WHERE a.mode = 'chat'
|
||||||
AND am.agent_mode is not null
|
AND am.agent_mode is not null
|
||||||
AND (
|
AND (
|
||||||
am.agent_mode like '%"strategy": "function_call"%'
|
am.agent_mode like '%"strategy": "function_call"%'
|
||||||
OR am.agent_mode like '%"strategy": "react"%'
|
OR am.agent_mode like '%"strategy": "react"%'
|
||||||
)
|
)
|
||||||
AND (
|
AND (
|
||||||
am.agent_mode like '{"enabled": true%'
|
am.agent_mode like '{"enabled": true%'
|
||||||
OR am.agent_mode like '{"max_iteration": %'
|
OR am.agent_mode like '{"max_iteration": %'
|
||||||
) ORDER BY a.created_at DESC LIMIT 1000
|
) ORDER BY a.created_at DESC LIMIT 1000
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Annotated, Optional
|
from typing import Annotated, Literal, Optional
|
||||||
|
|
||||||
from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field
|
from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
|
|
@ -11,11 +11,11 @@ class SecurityConfig(BaseSettings):
|
||||||
Security-related configurations for the application
|
Security-related configurations for the application
|
||||||
"""
|
"""
|
||||||
|
|
||||||
SECRET_KEY: Optional[str] = Field(
|
SECRET_KEY: str = Field(
|
||||||
description="Secret key for secure session cookie signing."
|
description="Secret key for secure session cookie signing."
|
||||||
"Make sure you are changing this key for your deployment with a strong key."
|
"Make sure you are changing this key for your deployment with a strong key."
|
||||||
"Generate a strong key using `openssl rand -base64 42` or set via the `SECRET_KEY` environment variable.",
|
"Generate a strong key using `openssl rand -base64 42` or set via the `SECRET_KEY` environment variable.",
|
||||||
default=None,
|
default="",
|
||||||
)
|
)
|
||||||
|
|
||||||
RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field(
|
RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field(
|
||||||
|
|
@ -177,6 +177,16 @@ class FileUploadConfig(BaseSettings):
|
||||||
default=10,
|
default=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
UPLOAD_VIDEO_FILE_SIZE_LIMIT: NonNegativeInt = Field(
|
||||||
|
description="video file size limit in Megabytes for uploading files",
|
||||||
|
default=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
UPLOAD_AUDIO_FILE_SIZE_LIMIT: NonNegativeInt = Field(
|
||||||
|
description="audio file size limit in Megabytes for uploading files",
|
||||||
|
default=50,
|
||||||
|
)
|
||||||
|
|
||||||
BATCH_UPLOAD_LIMIT: NonNegativeInt = Field(
|
BATCH_UPLOAD_LIMIT: NonNegativeInt = Field(
|
||||||
description="Maximum number of files allowed in a batch upload operation",
|
description="Maximum number of files allowed in a batch upload operation",
|
||||||
default=20,
|
default=20,
|
||||||
|
|
@ -349,8 +359,8 @@ class WorkflowConfig(BaseSettings):
|
||||||
)
|
)
|
||||||
|
|
||||||
MAX_VARIABLE_SIZE: PositiveInt = Field(
|
MAX_VARIABLE_SIZE: PositiveInt = Field(
|
||||||
description="Maximum size in bytes for a single variable in workflows. Default to 5KB.",
|
description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.",
|
||||||
default=5 * 1024,
|
default=200 * 1024,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -468,6 +478,7 @@ class RagEtlConfig(BaseSettings):
|
||||||
Configuration for RAG ETL processes
|
Configuration for RAG ETL processes
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# TODO: This config is not only for rag etl, it is also for file upload, we should move it to file upload config
|
||||||
ETL_TYPE: str = Field(
|
ETL_TYPE: str = Field(
|
||||||
description="RAG ETL type ('dify' or 'Unstructured'), default to 'dify'",
|
description="RAG ETL type ('dify' or 'Unstructured'), default to 'dify'",
|
||||||
default="dify",
|
default="dify",
|
||||||
|
|
@ -529,7 +540,7 @@ class IndexingConfig(BaseSettings):
|
||||||
|
|
||||||
|
|
||||||
class ImageFormatConfig(BaseSettings):
|
class ImageFormatConfig(BaseSettings):
|
||||||
MULTIMODAL_SEND_IMAGE_FORMAT: str = Field(
|
MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field(
|
||||||
description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
|
description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
|
||||||
default="base64",
|
default="base64",
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
|
||||||
|
|
||||||
CURRENT_VERSION: str = Field(
|
CURRENT_VERSION: str = Field(
|
||||||
description="Dify version",
|
description="Dify version",
|
||||||
default="0.9.1",
|
default="0.10.0-beta1",
|
||||||
)
|
)
|
||||||
|
|
||||||
COMMIT_SHA: str = Field(
|
COMMIT_SHA: str = Field(
|
||||||
|
|
|
||||||
|
|
@ -1,2 +1,21 @@
|
||||||
|
from configs import dify_config
|
||||||
|
|
||||||
HIDDEN_VALUE = "[__HIDDEN__]"
|
HIDDEN_VALUE = "[__HIDDEN__]"
|
||||||
UUID_NIL = "00000000-0000-0000-0000-000000000000"
|
UUID_NIL = "00000000-0000-0000-0000-000000000000"
|
||||||
|
|
||||||
|
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
|
||||||
|
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
|
||||||
|
|
||||||
|
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "mpga"]
|
||||||
|
VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])
|
||||||
|
|
||||||
|
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "webm", "amr"]
|
||||||
|
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
|
||||||
|
|
||||||
|
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
|
||||||
|
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
|
||||||
|
|
||||||
|
if dify_config.ETL_TYPE == "Unstructured":
|
||||||
|
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls"]
|
||||||
|
DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "ppt", "xml", "epub"))
|
||||||
|
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
if TYPE_CHECKING:
|
||||||
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
|
||||||
tenant_id: ContextVar[str] = ContextVar("tenant_id")
|
tenant_id: ContextVar[str] = ContextVar("tenant_id")
|
||||||
|
|
||||||
workflow_variable_pool: ContextVar[VariablePool] = ContextVar("workflow_variable_pool")
|
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,8 @@ from fields.conversation_fields import (
|
||||||
)
|
)
|
||||||
from libs.helper import DatetimeString
|
from libs.helper import DatetimeString
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation
|
from models import Conversation, EndUser, Message, MessageAnnotation
|
||||||
|
from models.model import AppMode
|
||||||
|
|
||||||
|
|
||||||
class CompletionConversationApi(Resource):
|
class CompletionConversationApi(Resource):
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ from controllers.console.wraps import account_initialization_required
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.app_fields import app_site_fields
|
from fields.app_fields import app_site_fields
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.model import Site
|
from models import Site
|
||||||
|
|
||||||
|
|
||||||
def parse_app_site_args():
|
def parse_app_site_args():
|
||||||
|
|
|
||||||
|
|
@ -13,14 +13,14 @@ from controllers.console.setup import setup_required
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.app.segments import factory
|
from factories import variable_factory
|
||||||
from core.errors.error import AppInvokeQuotaExceededError
|
|
||||||
from fields.workflow_fields import workflow_fields
|
from fields.workflow_fields import workflow_fields
|
||||||
from fields.workflow_run_fields import workflow_run_node_execution_fields
|
from fields.workflow_run_fields import workflow_run_node_execution_fields
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.helper import TimestampField, uuid_value
|
from libs.helper import TimestampField, uuid_value
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_user, login_required
|
||||||
from models.model import App, AppMode
|
from models import App
|
||||||
|
from models.model import AppMode
|
||||||
from services.app_dsl_service import AppDslService
|
from services.app_dsl_service import AppDslService
|
||||||
from services.app_generate_service import AppGenerateService
|
from services.app_generate_service import AppGenerateService
|
||||||
from services.errors.app import WorkflowHashNotEqualError
|
from services.errors.app import WorkflowHashNotEqualError
|
||||||
|
|
@ -101,9 +101,13 @@ class DraftWorkflowApi(Resource):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
environment_variables_list = args.get("environment_variables") or []
|
environment_variables_list = args.get("environment_variables") or []
|
||||||
environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
|
environment_variables = [
|
||||||
|
variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list
|
||||||
|
]
|
||||||
conversation_variables_list = args.get("conversation_variables") or []
|
conversation_variables_list = args.get("conversation_variables") or []
|
||||||
conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
|
conversation_variables = [
|
||||||
|
variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||||
|
]
|
||||||
workflow = workflow_service.sync_draft_workflow(
|
workflow = workflow_service.sync_draft_workflow(
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
graph=args["graph"],
|
graph=args["graph"],
|
||||||
|
|
@ -273,17 +277,15 @@ class DraftWorkflowRunApi(Resource):
|
||||||
parser.add_argument("files", type=list, required=False, location="json")
|
parser.add_argument("files", type=list, required=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
response = AppGenerateService.generate(
|
||||||
response = AppGenerateService.generate(
|
app_model=app_model,
|
||||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True
|
user=current_user,
|
||||||
)
|
args=args,
|
||||||
|
invoke_from=InvokeFrom.DEBUGGER,
|
||||||
|
streaming=True,
|
||||||
|
)
|
||||||
|
|
||||||
return helper.compact_generate_response(response)
|
return helper.compact_generate_response(response)
|
||||||
except (ValueError, AppInvokeQuotaExceededError) as e:
|
|
||||||
raise e
|
|
||||||
except Exception as e:
|
|
||||||
logging.exception("internal server error.")
|
|
||||||
raise InternalServerError()
|
|
||||||
|
|
||||||
|
|
||||||
class WorkflowTaskStopApi(Resource):
|
class WorkflowTaskStopApi(Resource):
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,8 @@ from controllers.console.setup import setup_required
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.model import App, AppMode
|
from models import App
|
||||||
|
from models.model import AppMode
|
||||||
from services.workflow_app_service import WorkflowAppService
|
from services.workflow_app_service import WorkflowAppService
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,8 @@ from fields.workflow_run_fields import (
|
||||||
)
|
)
|
||||||
from libs.helper import uuid_value
|
from libs.helper import uuid_value
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.model import App, AppMode
|
from models import App
|
||||||
|
from models.model import AppMode
|
||||||
from services.workflow_run_service import WorkflowRunService
|
from services.workflow_run_service import WorkflowRunService
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,11 +10,11 @@ from controllers.console import api
|
||||||
from controllers.console.app.wraps import get_app_model
|
from controllers.console.app.wraps import get_app_model
|
||||||
from controllers.console.setup import setup_required
|
from controllers.console.setup import setup_required
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
|
from enums import WorkflowRunTriggeredFrom
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import DatetimeString
|
from libs.helper import DatetimeString
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
from models.workflow import WorkflowRunTriggeredFrom
|
|
||||||
|
|
||||||
|
|
||||||
class WorkflowDailyRunsStatistic(Resource):
|
class WorkflowDailyRunsStatistic(Resource):
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,8 @@ from typing import Optional, Union
|
||||||
from controllers.console.app.error import AppNotFoundError
|
from controllers.console.app.error import AppNotFoundError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.login import current_user
|
from libs.login import current_user
|
||||||
from models.model import App, AppMode
|
from models import App
|
||||||
|
from models.model import AppMode
|
||||||
|
|
||||||
|
|
||||||
def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None):
|
def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None):
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ from controllers.console.setup import setup_required
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import email as email_validate
|
from libs.helper import email as email_validate
|
||||||
from libs.password import hash_password, valid_password
|
from libs.password import hash_password, valid_password
|
||||||
from models.account import Account
|
from models import Account
|
||||||
from services.account_service import AccountService
|
from services.account_service import AccountService
|
||||||
from services.errors.account import RateLimitExceededError
|
from services.errors.account import RateLimitExceededError
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from controllers.console import api
|
||||||
from controllers.console.setup import setup_required
|
from controllers.console.setup import setup_required
|
||||||
from libs.helper import email, get_remote_ip
|
from libs.helper import email, get_remote_ip
|
||||||
from libs.password import valid_password
|
from libs.password import valid_password
|
||||||
from models.account import Account
|
from models import Account
|
||||||
from services.account_service import AccountService, TenantService
|
from services.account_service import AccountService, TenantService
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,8 @@ from constants.languages import languages
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import get_remote_ip
|
from libs.helper import get_remote_ip
|
||||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
||||||
from models.account import Account, AccountStatus
|
from models import Account
|
||||||
|
from models.account import AccountStatus
|
||||||
from services.account_service import AccountService, RegisterService, TenantService
|
from services.account_service import AccountService, RegisterService, TenantService
|
||||||
|
|
||||||
from .. import api
|
from .. import api
|
||||||
|
|
|
||||||
|
|
@ -15,8 +15,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
|
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.dataset import Document
|
from models import DataSourceOauthBinding, Document
|
||||||
from models.source import DataSourceOauthBinding
|
|
||||||
from services.dataset_service import DatasetService, DocumentService
|
from services.dataset_service import DatasetService, DocumentService
|
||||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -24,8 +24,8 @@ from fields.app_fields import related_app_list
|
||||||
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
|
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
|
||||||
from fields.document_fields import document_status_fields
|
from fields.document_fields import document_status_fields
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment
|
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||||
from models.model import ApiToken, UploadFile
|
from models.dataset import DatasetPermissionEnum
|
||||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -46,8 +46,7 @@ from fields.document_fields import (
|
||||||
document_with_segments_fields,
|
document_with_segments_fields,
|
||||||
)
|
)
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.dataset import Dataset, DatasetProcessRule, Document, DocumentSegment
|
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||||
from models.model import UploadFile
|
|
||||||
from services.dataset_service import DatasetService, DocumentService
|
from services.dataset_service import DatasetService, DocumentService
|
||||||
from tasks.add_document_to_index_task import add_document_to_index_task
|
from tasks.add_document_to_index_task import add_document_to_index_task
|
||||||
from tasks.remove_document_from_index_task import remove_document_from_index_task
|
from tasks.remove_document_from_index_task import remove_document_from_index_task
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from fields.segment_fields import segment_fields
|
from fields.segment_fields import segment_fields
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.dataset import DocumentSegment
|
from models import DocumentSegment
|
||||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||||
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
||||||
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
|
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,12 @@
|
||||||
|
import urllib.parse
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restful import Resource, marshal_with
|
from flask_restful import Resource, marshal_with
|
||||||
|
|
||||||
import services
|
import services
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
from constants import DOCUMENT_EXTENSIONS
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.datasets.error import (
|
from controllers.console.datasets.error import (
|
||||||
FileTooLargeError,
|
FileTooLargeError,
|
||||||
|
|
@ -13,9 +16,10 @@ from controllers.console.datasets.error import (
|
||||||
)
|
)
|
||||||
from controllers.console.setup import setup_required
|
from controllers.console.setup import setup_required
|
||||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||||
from fields.file_fields import file_fields, upload_config_fields
|
from core.helper import ssrf_proxy
|
||||||
|
from fields.file_fields import file_fields, remote_file_info_fields, upload_config_fields
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from services.file_service import ALLOWED_EXTENSIONS, UNSTRUCTURED_ALLOWED_EXTENSIONS, FileService
|
from services.file_service import FileService
|
||||||
|
|
||||||
PREVIEW_WORDS_LIMIT = 3000
|
PREVIEW_WORDS_LIMIT = 3000
|
||||||
|
|
||||||
|
|
@ -51,7 +55,7 @@ class FileApi(Resource):
|
||||||
if len(request.files) > 1:
|
if len(request.files) > 1:
|
||||||
raise TooManyFilesError()
|
raise TooManyFilesError()
|
||||||
try:
|
try:
|
||||||
upload_file = FileService.upload_file(file, current_user)
|
upload_file = FileService.upload_file(file=file, user=current_user)
|
||||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||||
raise FileTooLargeError(file_too_large_error.description)
|
raise FileTooLargeError(file_too_large_error.description)
|
||||||
except services.errors.file.UnsupportedFileTypeError:
|
except services.errors.file.UnsupportedFileTypeError:
|
||||||
|
|
@ -75,11 +79,24 @@ class FileSupportTypeApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
etl_type = dify_config.ETL_TYPE
|
return {"allowed_extensions": DOCUMENT_EXTENSIONS}
|
||||||
allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS
|
|
||||||
return {"allowed_extensions": allowed_extensions}
|
|
||||||
|
class RemoteFileInfoApi(Resource):
|
||||||
|
@marshal_with(remote_file_info_fields)
|
||||||
|
def get(self, url):
|
||||||
|
decoded_url = urllib.parse.unquote(url)
|
||||||
|
try:
|
||||||
|
response = ssrf_proxy.head(decoded_url)
|
||||||
|
return {
|
||||||
|
"file_type": response.headers.get("Content-Type", "application/octet-stream"),
|
||||||
|
"file_length": int(response.headers.get("Content-Length", 0)),
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": str(e)}, 400
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(FileApi, "/files/upload")
|
api.add_resource(FileApi, "/files/upload")
|
||||||
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
|
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
|
||||||
api.add_resource(FileSupportTypeApi, "/files/support-type")
|
api.add_resource(FileSupportTypeApi, "/files/support-type")
|
||||||
|
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.installed_app_fields import installed_app_list_fields
|
from fields.installed_app_fields import installed_app_list_fields
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.model import App, InstalledApp, RecommendedApp
|
from models import App, InstalledApp, RecommendedApp
|
||||||
from services.account_service import TenantService
|
from services.account_service import TenantService
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ from werkzeug.exceptions import NotFound
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.model import InstalledApp
|
from models import InstalledApp
|
||||||
|
|
||||||
|
|
||||||
def installed_app_required(view=None):
|
def installed_app_required(view=None):
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@ from extensions.ext_database import db
|
||||||
from fields.member_fields import account_fields
|
from fields.member_fields import account_fields
|
||||||
from libs.helper import TimestampField, timezone
|
from libs.helper import TimestampField, timezone
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.account import AccountIntegrate, InvitationCode
|
from models import AccountIntegrate, InvitationCode
|
||||||
from services.account_service import AccountService
|
from services.account_service import AccountService
|
||||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -360,16 +360,15 @@ class ToolWorkflowProviderCreateApi(Resource):
|
||||||
args = reqparser.parse_args()
|
args = reqparser.parse_args()
|
||||||
|
|
||||||
return WorkflowToolManageService.create_workflow_tool(
|
return WorkflowToolManageService.create_workflow_tool(
|
||||||
user_id,
|
user_id=user_id,
|
||||||
tenant_id,
|
tenant_id=tenant_id,
|
||||||
args["workflow_app_id"],
|
workflow_app_id=args["workflow_app_id"],
|
||||||
args["name"],
|
name=args["name"],
|
||||||
args["label"],
|
label=args["label"],
|
||||||
args["icon"],
|
icon=args["icon"],
|
||||||
args["description"],
|
description=args["description"],
|
||||||
args["parameters"],
|
parameters=args["parameters"],
|
||||||
args["privacy_policy"],
|
privacy_policy=args["privacy_policy"],
|
||||||
args.get("labels", []),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -198,7 +198,7 @@ class WebappLogoWorkspaceApi(Resource):
|
||||||
raise UnsupportedFileTypeError()
|
raise UnsupportedFileTypeError()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
upload_file = FileService.upload_file(file, current_user, True)
|
upload_file = FileService.upload_file(file=file, user=current_user)
|
||||||
|
|
||||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||||
raise FileTooLargeError(file_too_large_error.description)
|
raise FileTooLargeError(file_too_large_error.description)
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,36 @@ class ImagePreviewApi(Resource):
|
||||||
return {"content": "Invalid request."}, 400
|
return {"content": "Invalid request."}, 400
|
||||||
|
|
||||||
try:
|
try:
|
||||||
generator, mimetype = FileService.get_image_preview(file_id, timestamp, nonce, sign)
|
generator, mimetype = FileService.get_image_preview(
|
||||||
|
file_id=file_id,
|
||||||
|
timestamp=timestamp,
|
||||||
|
nonce=nonce,
|
||||||
|
sign=sign,
|
||||||
|
)
|
||||||
|
except services.errors.file.UnsupportedFileTypeError:
|
||||||
|
raise UnsupportedFileTypeError()
|
||||||
|
|
||||||
|
return Response(generator, mimetype=mimetype)
|
||||||
|
|
||||||
|
|
||||||
|
class FilePreviewApi(Resource):
|
||||||
|
def get(self, file_id):
|
||||||
|
file_id = str(file_id)
|
||||||
|
|
||||||
|
timestamp = request.args.get("timestamp")
|
||||||
|
nonce = request.args.get("nonce")
|
||||||
|
sign = request.args.get("sign")
|
||||||
|
|
||||||
|
if not timestamp or not nonce or not sign:
|
||||||
|
return {"content": "Invalid request."}, 400
|
||||||
|
|
||||||
|
try:
|
||||||
|
generator, mimetype = FileService.get_signed_file_preview(
|
||||||
|
file_id=file_id,
|
||||||
|
timestamp=timestamp,
|
||||||
|
nonce=nonce,
|
||||||
|
sign=sign,
|
||||||
|
)
|
||||||
except services.errors.file.UnsupportedFileTypeError:
|
except services.errors.file.UnsupportedFileTypeError:
|
||||||
raise UnsupportedFileTypeError()
|
raise UnsupportedFileTypeError()
|
||||||
|
|
||||||
|
|
@ -49,6 +78,7 @@ class WorkspaceWebappLogoApi(Resource):
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(ImagePreviewApi, "/files/<uuid:file_id>/image-preview")
|
api.add_resource(ImagePreviewApi, "/files/<uuid:file_id>/image-preview")
|
||||||
|
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/file-preview")
|
||||||
api.add_resource(WorkspaceWebappLogoApi, "/files/workspaces/<uuid:workspace_id>/webapp-logo")
|
api.add_resource(WorkspaceWebappLogoApi, "/files/workspaces/<uuid:workspace_id>/webapp-logo")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
import urllib.parse
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restful import marshal_with
|
from flask_restful import marshal_with
|
||||||
|
|
||||||
|
|
@ -5,7 +7,8 @@ import services
|
||||||
from controllers.web import api
|
from controllers.web import api
|
||||||
from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError
|
from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError
|
||||||
from controllers.web.wraps import WebApiResource
|
from controllers.web.wraps import WebApiResource
|
||||||
from fields.file_fields import file_fields
|
from core.helper import ssrf_proxy
|
||||||
|
from fields.file_fields import file_fields, remote_file_info_fields
|
||||||
from services.file_service import FileService
|
from services.file_service import FileService
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -31,4 +34,19 @@ class FileApi(WebApiResource):
|
||||||
return upload_file, 201
|
return upload_file, 201
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteFileInfoApi(WebApiResource):
|
||||||
|
@marshal_with(remote_file_info_fields)
|
||||||
|
def get(self, url):
|
||||||
|
decoded_url = urllib.parse.unquote(url)
|
||||||
|
try:
|
||||||
|
response = ssrf_proxy.head(decoded_url)
|
||||||
|
return {
|
||||||
|
"file_type": response.headers.get("Content-Type", "application/octet-stream"),
|
||||||
|
"file_length": int(response.headers.get("Content-Length", 0)),
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": str(e)}, 400
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(FileApi, "/files/upload")
|
api.add_resource(FileApi, "/files/upload")
|
||||||
|
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
from fields.conversation_fields import message_file_fields
|
from fields.conversation_fields import message_file_fields
|
||||||
from fields.message_fields import agent_thought_fields
|
from fields.message_fields import agent_thought_fields
|
||||||
|
from fields.raws import FilesContainedField
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.helper import TimestampField, uuid_value
|
from libs.helper import TimestampField, uuid_value
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
|
|
@ -58,7 +59,7 @@ class MessageListApi(WebApiResource):
|
||||||
"id": fields.String,
|
"id": fields.String,
|
||||||
"conversation_id": fields.String,
|
"conversation_id": fields.String,
|
||||||
"parent_message_id": fields.String,
|
"parent_message_id": fields.String,
|
||||||
"inputs": fields.Raw,
|
"inputs": FilesContainedField,
|
||||||
"query": fields.String,
|
"query": fields.String,
|
||||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||||
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
|
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
|
||||||
|
|
|
||||||
|
|
@ -16,13 +16,14 @@ from core.app.entities.app_invoke_entities import (
|
||||||
)
|
)
|
||||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||||
from core.file.message_file_parser import MessageFileParser
|
from core.file import file_manager
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
from core.model_runtime.entities import (
|
||||||
from core.model_runtime.entities.message_entities import (
|
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
|
LLMUsage,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
|
PromptMessageContent,
|
||||||
PromptMessageTool,
|
PromptMessageTool,
|
||||||
SystemPromptMessage,
|
SystemPromptMessage,
|
||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
|
|
@ -40,8 +41,8 @@ from core.tools.entities.tool_entities import (
|
||||||
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
|
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||||
from core.tools.tool.tool import Tool
|
from core.tools.tool.tool import Tool
|
||||||
from core.tools.tool_manager import ToolManager
|
from core.tools.tool_manager import ToolManager
|
||||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from factories import file_factory
|
||||||
from models.model import Conversation, Message, MessageAgentThought
|
from models.model import Conversation, Message, MessageAgentThought
|
||||||
from models.tools import ToolConversationVariables
|
from models.tools import ToolConversationVariables
|
||||||
|
|
||||||
|
|
@ -66,23 +67,6 @@ class BaseAgentRunner(AppRunner):
|
||||||
db_variables: Optional[ToolConversationVariables] = None,
|
db_variables: Optional[ToolConversationVariables] = None,
|
||||||
model_instance: ModelInstance = None,
|
model_instance: ModelInstance = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
|
||||||
Agent runner
|
|
||||||
:param tenant_id: tenant id
|
|
||||||
:param application_generate_entity: application generate entity
|
|
||||||
:param conversation: conversation
|
|
||||||
:param app_config: app generate entity
|
|
||||||
:param model_config: model config
|
|
||||||
:param config: dataset config
|
|
||||||
:param queue_manager: queue manager
|
|
||||||
:param message: message
|
|
||||||
:param user_id: user id
|
|
||||||
:param memory: memory
|
|
||||||
:param prompt_messages: prompt messages
|
|
||||||
:param variables_pool: variables pool
|
|
||||||
:param db_variables: db variables
|
|
||||||
:param model_instance: model instance
|
|
||||||
"""
|
|
||||||
self.tenant_id = tenant_id
|
self.tenant_id = tenant_id
|
||||||
self.application_generate_entity = application_generate_entity
|
self.application_generate_entity = application_generate_entity
|
||||||
self.conversation = conversation
|
self.conversation = conversation
|
||||||
|
|
@ -180,7 +164,7 @@ class BaseAgentRunner(AppRunner):
|
||||||
if parameter.form != ToolParameter.ToolParameterForm.LLM:
|
if parameter.form != ToolParameter.ToolParameterForm.LLM:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
|
parameter_type = parameter.type.as_normal_type()
|
||||||
enum = []
|
enum = []
|
||||||
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
||||||
enum = [option.value for option in parameter.options]
|
enum = [option.value for option in parameter.options]
|
||||||
|
|
@ -265,7 +249,7 @@ class BaseAgentRunner(AppRunner):
|
||||||
if parameter.form != ToolParameter.ToolParameterForm.LLM:
|
if parameter.form != ToolParameter.ToolParameterForm.LLM:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
|
parameter_type = parameter.type.as_normal_type()
|
||||||
enum = []
|
enum = []
|
||||||
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
||||||
enum = [option.value for option in parameter.options]
|
enum = [option.value for option in parameter.options]
|
||||||
|
|
@ -511,26 +495,24 @@ class BaseAgentRunner(AppRunner):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
|
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
|
||||||
message_file_parser = MessageFileParser(
|
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
app_id=self.app_config.app_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
files = message.message_files
|
files = message.message_files
|
||||||
if files:
|
if files:
|
||||||
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
|
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
|
||||||
|
|
||||||
if file_extra_config:
|
if file_extra_config:
|
||||||
file_objs = message_file_parser.transform_message_files(files, file_extra_config)
|
file_objs = file_factory.build_from_message_files(
|
||||||
|
message_files=files, tenant_id=self.tenant_id, config=file_extra_config
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
file_objs = []
|
file_objs = []
|
||||||
|
|
||||||
if not file_objs:
|
if not file_objs:
|
||||||
return UserPromptMessage(content=message.query)
|
return UserPromptMessage(content=message.query)
|
||||||
else:
|
else:
|
||||||
prompt_message_contents = [TextPromptMessageContent(data=message.query)]
|
prompt_message_contents: list[PromptMessageContent] = []
|
||||||
|
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
|
||||||
for file_obj in file_objs:
|
for file_obj in file_objs:
|
||||||
prompt_message_contents.append(file_obj.prompt_message_content)
|
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
|
||||||
|
|
||||||
return UserPromptMessage(content=prompt_message_contents)
|
return UserPromptMessage(content=prompt_message_contents)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,11 @@
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from core.agent.cot_agent_runner import CotAgentRunner
|
from core.agent.cot_agent_runner import CotAgentRunner
|
||||||
from core.model_runtime.entities.message_entities import (
|
from core.file import file_manager
|
||||||
|
from core.model_runtime.entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
|
PromptMessageContent,
|
||||||
SystemPromptMessage,
|
SystemPromptMessage,
|
||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
|
|
@ -32,9 +34,10 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||||
Organize user query
|
Organize user query
|
||||||
"""
|
"""
|
||||||
if self.files:
|
if self.files:
|
||||||
prompt_message_contents = [TextPromptMessageContent(data=query)]
|
prompt_message_contents: list[PromptMessageContent] = []
|
||||||
|
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||||
for file_obj in self.files:
|
for file_obj in self.files:
|
||||||
prompt_message_contents.append(file_obj.prompt_message_content)
|
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
|
||||||
|
|
||||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -7,10 +7,15 @@ from typing import Any, Union
|
||||||
from core.agent.base_agent_runner import BaseAgentRunner
|
from core.agent.base_agent_runner import BaseAgentRunner
|
||||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||||
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
|
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
from core.file import file_manager
|
||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
|
LLMResult,
|
||||||
|
LLMResultChunk,
|
||||||
|
LLMResultChunkDelta,
|
||||||
|
LLMUsage,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
|
PromptMessageContent,
|
||||||
PromptMessageContentType,
|
PromptMessageContentType,
|
||||||
SystemPromptMessage,
|
SystemPromptMessage,
|
||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
|
|
@ -390,9 +395,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||||
Organize user query
|
Organize user query
|
||||||
"""
|
"""
|
||||||
if self.files:
|
if self.files:
|
||||||
prompt_message_contents = [TextPromptMessageContent(data=query)]
|
prompt_message_contents: list[PromptMessageContent] = []
|
||||||
|
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||||
for file_obj in self.files:
|
for file_obj in self.files:
|
||||||
prompt_message_contents.append(file_obj.prompt_message_content)
|
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
|
||||||
|
|
||||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -53,12 +53,11 @@ class BasicVariablesConfigManager:
|
||||||
VariableEntity(
|
VariableEntity(
|
||||||
type=variable_type,
|
type=variable_type,
|
||||||
variable=variable.get("variable"),
|
variable=variable.get("variable"),
|
||||||
description=variable.get("description"),
|
description=variable.get("description", ""),
|
||||||
label=variable.get("label"),
|
label=variable.get("label"),
|
||||||
required=variable.get("required", False),
|
required=variable.get("required", False),
|
||||||
max_length=variable.get("max_length"),
|
max_length=variable.get("max_length"),
|
||||||
options=variable.get("options"),
|
options=variable.get("options", []),
|
||||||
default=variable.get("default"),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,12 @@
|
||||||
|
from collections.abc import Sequence
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from core.file.file_obj import FileExtraConfig
|
from core.file import FileExtraConfig, FileTransferMethod, FileType
|
||||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||||
from models import AppMode
|
from models.model import AppMode
|
||||||
|
|
||||||
|
|
||||||
class ModelConfigEntity(BaseModel):
|
class ModelConfigEntity(BaseModel):
|
||||||
|
|
@ -69,7 +70,7 @@ class PromptTemplateEntity(BaseModel):
|
||||||
ADVANCED = "advanced"
|
ADVANCED = "advanced"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_of(cls, value: str) -> "PromptType":
|
def value_of(cls, value: str):
|
||||||
"""
|
"""
|
||||||
Get value of given mode.
|
Get value of given mode.
|
||||||
|
|
||||||
|
|
@ -93,6 +94,8 @@ class VariableEntityType(str, Enum):
|
||||||
PARAGRAPH = "paragraph"
|
PARAGRAPH = "paragraph"
|
||||||
NUMBER = "number"
|
NUMBER = "number"
|
||||||
EXTERNAL_DATA_TOOL = "external_data_tool"
|
EXTERNAL_DATA_TOOL = "external_data_tool"
|
||||||
|
FILE = "file"
|
||||||
|
FILE_LIST = "file-list"
|
||||||
|
|
||||||
|
|
||||||
class VariableEntity(BaseModel):
|
class VariableEntity(BaseModel):
|
||||||
|
|
@ -102,13 +105,14 @@ class VariableEntity(BaseModel):
|
||||||
|
|
||||||
variable: str
|
variable: str
|
||||||
label: str
|
label: str
|
||||||
description: Optional[str] = None
|
description: str = ""
|
||||||
type: VariableEntityType
|
type: VariableEntityType
|
||||||
required: bool = False
|
required: bool = False
|
||||||
max_length: Optional[int] = None
|
max_length: Optional[int] = None
|
||||||
options: Optional[list[str]] = None
|
options: Sequence[str] = Field(default_factory=list)
|
||||||
default: Optional[str] = None
|
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
|
||||||
hint: Optional[str] = None
|
allowed_file_extensions: Sequence[str] = Field(default_factory=list)
|
||||||
|
allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class ExternalDataVariableEntity(BaseModel):
|
class ExternalDataVariableEntity(BaseModel):
|
||||||
|
|
@ -136,7 +140,7 @@ class DatasetRetrieveConfigEntity(BaseModel):
|
||||||
MULTIPLE = "multiple"
|
MULTIPLE = "multiple"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_of(cls, value: str) -> "RetrieveStrategy":
|
def value_of(cls, value: str):
|
||||||
"""
|
"""
|
||||||
Get value of given mode.
|
Get value of given mode.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,13 @@
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
from core.file.file_obj import FileExtraConfig
|
from core.file.models import FileExtraConfig
|
||||||
|
from models import FileUploadConfig
|
||||||
|
|
||||||
|
|
||||||
class FileUploadConfigManager:
|
class FileUploadConfigManager:
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert(cls, config: Mapping[str, Any], is_vision: bool = True) -> Optional[FileExtraConfig]:
|
def convert(cls, config: Mapping[str, Any], is_vision: bool = True):
|
||||||
"""
|
"""
|
||||||
Convert model config to model config
|
Convert model config to model config
|
||||||
|
|
||||||
|
|
@ -15,19 +16,18 @@ class FileUploadConfigManager:
|
||||||
"""
|
"""
|
||||||
file_upload_dict = config.get("file_upload")
|
file_upload_dict = config.get("file_upload")
|
||||||
if file_upload_dict:
|
if file_upload_dict:
|
||||||
if file_upload_dict.get("image"):
|
if file_upload_dict.get("enabled"):
|
||||||
if "enabled" in file_upload_dict["image"] and file_upload_dict["image"]["enabled"]:
|
data = {
|
||||||
image_config = {
|
"image_config": {
|
||||||
"number_limits": file_upload_dict["image"]["number_limits"],
|
"number_limits": file_upload_dict["number_limits"],
|
||||||
"transfer_methods": file_upload_dict["image"]["transfer_methods"],
|
"transfer_methods": file_upload_dict["allowed_file_upload_methods"],
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if is_vision:
|
if is_vision:
|
||||||
image_config["detail"] = file_upload_dict["image"]["detail"]
|
data["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "low")
|
||||||
|
|
||||||
return FileExtraConfig(image_config=image_config)
|
return FileExtraConfig.model_validate(data)
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_and_set_defaults(cls, config: dict, is_vision: bool = True) -> tuple[dict, list[str]]:
|
def validate_and_set_defaults(cls, config: dict, is_vision: bool = True) -> tuple[dict, list[str]]:
|
||||||
|
|
@ -39,29 +39,7 @@ class FileUploadConfigManager:
|
||||||
"""
|
"""
|
||||||
if not config.get("file_upload"):
|
if not config.get("file_upload"):
|
||||||
config["file_upload"] = {}
|
config["file_upload"] = {}
|
||||||
|
else:
|
||||||
if not isinstance(config["file_upload"], dict):
|
FileUploadConfig.model_validate(config["file_upload"])
|
||||||
raise ValueError("file_upload must be of dict type")
|
|
||||||
|
|
||||||
# check image config
|
|
||||||
if not config["file_upload"].get("image"):
|
|
||||||
config["file_upload"]["image"] = {"enabled": False}
|
|
||||||
|
|
||||||
if config["file_upload"]["image"]["enabled"]:
|
|
||||||
number_limits = config["file_upload"]["image"]["number_limits"]
|
|
||||||
if number_limits < 1 or number_limits > 6:
|
|
||||||
raise ValueError("number_limits must be in [1, 6]")
|
|
||||||
|
|
||||||
if is_vision:
|
|
||||||
detail = config["file_upload"]["image"]["detail"]
|
|
||||||
if detail not in {"high", "low"}:
|
|
||||||
raise ValueError("detail must be in ['high', 'low']")
|
|
||||||
|
|
||||||
transfer_methods = config["file_upload"]["image"]["transfer_methods"]
|
|
||||||
if not isinstance(transfer_methods, list):
|
|
||||||
raise ValueError("transfer_methods must be of list type")
|
|
||||||
for method in transfer_methods:
|
|
||||||
if method not in {"remote_url", "local_file"}:
|
|
||||||
raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
|
|
||||||
|
|
||||||
return config, ["file_upload"]
|
return config, ["file_upload"]
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,6 @@ class WorkflowVariablesConfigManager:
|
||||||
|
|
||||||
# variables
|
# variables
|
||||||
for variable in user_input_form:
|
for variable in user_input_form:
|
||||||
variables.append(VariableEntity(**variable))
|
variables.append(VariableEntity.model_validate(variable))
|
||||||
|
|
||||||
return variables
|
return variables
|
||||||
|
|
|
||||||
|
|
@ -20,10 +20,11 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||||
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
||||||
from core.file.message_file_parser import MessageFileParser
|
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
|
from enums import CreatedByRole
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from factories import file_factory
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.model import App, Conversation, EndUser, Message
|
from models.model import App, Conversation, EndUser, Message
|
||||||
from models.workflow import Workflow
|
from models.workflow import Workflow
|
||||||
|
|
@ -95,10 +96,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||||
|
|
||||||
# parse files
|
# parse files
|
||||||
files = args["files"] if args.get("files") else []
|
files = args["files"] if args.get("files") else []
|
||||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
|
||||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||||
|
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
||||||
if file_extra_config:
|
if file_extra_config:
|
||||||
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
|
file_objs = file_factory.build_from_mappings(
|
||||||
|
mappings=files,
|
||||||
|
tenant_id=app_model.tenant_id,
|
||||||
|
user_id=user.id,
|
||||||
|
role=role,
|
||||||
|
config=file_extra_config,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
file_objs = []
|
file_objs = []
|
||||||
|
|
||||||
|
|
@ -106,8 +113,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||||
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
|
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
|
||||||
|
|
||||||
# get tracing instance
|
# get tracing instance
|
||||||
user_id = user.id if isinstance(user, Account) else user.session_id
|
trace_manager = TraceQueueManager(
|
||||||
trace_manager = TraceQueueManager(app_model.id, user_id)
|
app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
|
||||||
|
)
|
||||||
|
|
||||||
if invoke_from == InvokeFrom.DEBUGGER:
|
if invoke_from == InvokeFrom.DEBUGGER:
|
||||||
# always enable retriever resource in debugger mode
|
# always enable retriever resource in debugger mode
|
||||||
|
|
@ -118,7 +126,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||||
task_id=str(uuid.uuid4()),
|
task_id=str(uuid.uuid4()),
|
||||||
app_config=app_config,
|
app_config=app_config,
|
||||||
conversation_id=conversation.id if conversation else None,
|
conversation_id=conversation.id if conversation else None,
|
||||||
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
|
inputs=conversation.inputs
|
||||||
|
if conversation
|
||||||
|
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
||||||
query=query,
|
query=query,
|
||||||
files=file_objs,
|
files=file_objs,
|
||||||
parent_message_id=args.get("parent_message_id"),
|
parent_message_id=args.get("parent_message_id"),
|
||||||
|
|
|
||||||
|
|
@ -1,30 +1,26 @@
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||||
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
|
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||||
from core.app.entities.app_invoke_entities import (
|
|
||||||
AdvancedChatAppGenerateEntity,
|
|
||||||
InvokeFrom,
|
|
||||||
)
|
|
||||||
from core.app.entities.queue_entities import (
|
from core.app.entities.queue_entities import (
|
||||||
QueueAnnotationReplyEvent,
|
QueueAnnotationReplyEvent,
|
||||||
QueueStopEvent,
|
QueueStopEvent,
|
||||||
QueueTextChunkEvent,
|
QueueTextChunkEvent,
|
||||||
)
|
)
|
||||||
from core.moderation.base import ModerationError
|
from core.moderation.base import ModerationError
|
||||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
||||||
from core.workflow.entities.node_entities import UserFrom
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.workflow_entry import WorkflowEntry
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
|
from enums import UserFrom
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import App, Conversation, EndUser, Message
|
from models.model import App, Conversation, EndUser, Message
|
||||||
from models.workflow import ConversationVariable, WorkflowType
|
from models.workflow import ConversationVariable, WorkflowType
|
||||||
|
|
@ -44,12 +40,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||||
conversation: Conversation,
|
conversation: Conversation,
|
||||||
message: Message,
|
message: Message,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
|
||||||
:param application_generate_entity: application generate entity
|
|
||||||
:param queue_manager: application queue manager
|
|
||||||
:param conversation: conversation
|
|
||||||
:param message: message
|
|
||||||
"""
|
|
||||||
super().__init__(queue_manager)
|
super().__init__(queue_manager)
|
||||||
|
|
||||||
self.application_generate_entity = application_generate_entity
|
self.application_generate_entity = application_generate_entity
|
||||||
|
|
@ -57,10 +47,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||||
self.message = message
|
self.message = message
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
"""
|
|
||||||
Run application
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
app_config = self.application_generate_entity.app_config
|
app_config = self.application_generate_entity.app_config
|
||||||
app_config = cast(AdvancedChatAppConfig, app_config)
|
app_config = cast(AdvancedChatAppConfig, app_config)
|
||||||
|
|
||||||
|
|
@ -81,7 +67,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||||
user_id = self.application_generate_entity.user_id
|
user_id = self.application_generate_entity.user_id
|
||||||
|
|
||||||
workflow_callbacks: list[WorkflowCallback] = []
|
workflow_callbacks: list[WorkflowCallback] = []
|
||||||
if bool(os.environ.get("DEBUG", "False").lower() == "true"):
|
if dify_config.DEBUG:
|
||||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||||
|
|
||||||
if self.application_generate_entity.single_iteration_run:
|
if self.application_generate_entity.single_iteration_run:
|
||||||
|
|
@ -198,15 +184,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||||
query: str,
|
query: str,
|
||||||
message_id: str,
|
message_id: str,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
|
||||||
Handle input moderation
|
|
||||||
:param app_record: app record
|
|
||||||
:param app_generate_entity: application generate entity
|
|
||||||
:param inputs: inputs
|
|
||||||
:param query: query
|
|
||||||
:param message_id: message id
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# process sensitive_word_avoidance
|
# process sensitive_word_avoidance
|
||||||
_, inputs, query = self.moderation_for_inputs(
|
_, inputs, query = self.moderation_for_inputs(
|
||||||
|
|
@ -226,14 +203,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||||
def handle_annotation_reply(
|
def handle_annotation_reply(
|
||||||
self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
|
self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
|
||||||
Handle annotation reply
|
|
||||||
:param app_record: app record
|
|
||||||
:param message: message
|
|
||||||
:param query: query
|
|
||||||
:param app_generate_entity: application generate entity
|
|
||||||
"""
|
|
||||||
# annotation reply
|
|
||||||
annotation_reply = self.query_app_annotations_to_reply(
|
annotation_reply = self.query_app_annotations_to_reply(
|
||||||
app_record=app_record,
|
app_record=app_record,
|
||||||
message=message,
|
message=message,
|
||||||
|
|
@ -255,8 +224,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||||
def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
|
def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
|
||||||
"""
|
"""
|
||||||
Direct output
|
Direct output
|
||||||
:param text: text
|
|
||||||
:return:
|
|
||||||
"""
|
"""
|
||||||
self._publish_event(QueueTextChunkEvent(text=text))
|
self._publish_event(QueueTextChunkEvent(text=text))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator, Mapping
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||||
|
|
@ -49,6 +49,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
|
from enums.workflow_nodes import NodeType
|
||||||
from events.message_event import message_was_created
|
from events.message_event import message_was_created
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
|
|
@ -112,6 +113,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||||
self._task_state = WorkflowTaskState()
|
self._task_state = WorkflowTaskState()
|
||||||
|
|
||||||
self._conversation_name_generate_thread = None
|
self._conversation_name_generate_thread = None
|
||||||
|
self._recorded_files: list[Mapping[str, Any]] = []
|
||||||
|
|
||||||
def process(self):
|
def process(self):
|
||||||
"""
|
"""
|
||||||
|
|
@ -290,6 +292,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||||
elif isinstance(event, QueueNodeSucceededEvent):
|
elif isinstance(event, QueueNodeSucceededEvent):
|
||||||
workflow_node_execution = self._handle_workflow_node_execution_success(event)
|
workflow_node_execution = self._handle_workflow_node_execution_success(event)
|
||||||
|
|
||||||
|
# Record files if it's an answer node or end node
|
||||||
|
if event.node_type in [NodeType.ANSWER, NodeType.END]:
|
||||||
|
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
|
||||||
|
|
||||||
response = self._workflow_node_finish_to_stream_response(
|
response = self._workflow_node_finish_to_stream_response(
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
|
|
@ -356,7 +362,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||||
start_at=graph_runtime_state.start_at,
|
start_at=graph_runtime_state.start_at,
|
||||||
total_tokens=graph_runtime_state.total_tokens,
|
total_tokens=graph_runtime_state.total_tokens,
|
||||||
total_steps=graph_runtime_state.node_run_steps,
|
total_steps=graph_runtime_state.node_run_steps,
|
||||||
outputs=json.dumps(event.outputs) if event.outputs else None,
|
outputs=event.outputs,
|
||||||
conversation_id=self._conversation.id,
|
conversation_id=self._conversation.id,
|
||||||
trace_manager=trace_manager,
|
trace_manager=trace_manager,
|
||||||
)
|
)
|
||||||
|
|
@ -528,7 +534,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||||
del extras["metadata"]["annotation_reply"]
|
del extras["metadata"]["annotation_reply"]
|
||||||
|
|
||||||
return MessageEndStreamResponse(
|
return MessageEndStreamResponse(
|
||||||
task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
|
task_id=self._application_generate_entity.task_id, id=self._message.id, files=self._recorded_files, **extras
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_output_moderation_chunk(self, text: str) -> bool:
|
def _handle_output_moderation_chunk(self, text: str) -> bool:
|
||||||
|
|
|
||||||
|
|
@ -17,12 +17,12 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt
|
||||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
|
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
|
||||||
from core.file.message_file_parser import MessageFileParser
|
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
|
from enums import CreatedByRole
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.account import Account
|
from factories import file_factory
|
||||||
from models.model import App, EndUser
|
from models import Account, App, EndUser
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -49,7 +49,12 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||||
) -> dict: ...
|
) -> dict: ...
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
|
self,
|
||||||
|
app_model: App,
|
||||||
|
user: Union[Account, EndUser],
|
||||||
|
args: Any,
|
||||||
|
invoke_from: InvokeFrom,
|
||||||
|
stream: bool = True,
|
||||||
) -> Union[dict, Generator[dict, None, None]]:
|
) -> Union[dict, Generator[dict, None, None]]:
|
||||||
"""
|
"""
|
||||||
Generate App response.
|
Generate App response.
|
||||||
|
|
@ -97,12 +102,19 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||||
# always enable retriever resource in debugger mode
|
# always enable retriever resource in debugger mode
|
||||||
override_model_config_dict["retriever_resource"] = {"enabled": True}
|
override_model_config_dict["retriever_resource"] = {"enabled": True}
|
||||||
|
|
||||||
|
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
||||||
|
|
||||||
# parse files
|
# parse files
|
||||||
files = args["files"] if args.get("files") else []
|
files = args.get("files") or []
|
||||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
|
||||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||||
if file_extra_config:
|
if file_extra_config:
|
||||||
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
|
file_objs = file_factory.build_from_mappings(
|
||||||
|
mappings=files,
|
||||||
|
tenant_id=app_model.tenant_id,
|
||||||
|
user_id=user.id,
|
||||||
|
role=role,
|
||||||
|
config=file_extra_config,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
file_objs = []
|
file_objs = []
|
||||||
|
|
||||||
|
|
@ -115,8 +127,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||||
)
|
)
|
||||||
|
|
||||||
# get tracing instance
|
# get tracing instance
|
||||||
user_id = user.id if isinstance(user, Account) else user.session_id
|
trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id)
|
||||||
trace_manager = TraceQueueManager(app_model.id, user_id)
|
|
||||||
|
|
||||||
# init application generate entity
|
# init application generate entity
|
||||||
application_generate_entity = AgentChatAppGenerateEntity(
|
application_generate_entity = AgentChatAppGenerateEntity(
|
||||||
|
|
@ -124,7 +135,9 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||||
app_config=app_config,
|
app_config=app_config,
|
||||||
model_conf=ModelConfigConverter.convert(app_config),
|
model_conf=ModelConfigConverter.convert(app_config),
|
||||||
conversation_id=conversation.id if conversation else None,
|
conversation_id=conversation.id if conversation else None,
|
||||||
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
|
inputs=conversation.inputs
|
||||||
|
if conversation
|
||||||
|
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
||||||
query=query,
|
query=query,
|
||||||
files=file_objs,
|
files=file_objs,
|
||||||
parent_message_id=args.get("parent_message_id"),
|
parent_message_id=args.get("parent_message_id"),
|
||||||
|
|
|
||||||
|
|
@ -1,35 +1,92 @@
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType
|
from core.app.app_config.entities import VariableEntityType
|
||||||
|
from core.file import File, FileExtraConfig
|
||||||
|
from factories import file_factory
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core.app.app_config.entities import AppConfig, VariableEntity
|
||||||
|
from enums import CreatedByRole
|
||||||
|
|
||||||
|
|
||||||
class BaseAppGenerator:
|
class BaseAppGenerator:
|
||||||
def _get_cleaned_inputs(self, user_inputs: Optional[Mapping[str, Any]], app_config: AppConfig) -> Mapping[str, Any]:
|
def _prepare_user_inputs(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
user_inputs: Optional[Mapping[str, Any]],
|
||||||
|
app_config: "AppConfig",
|
||||||
|
user_id: str,
|
||||||
|
role: "CreatedByRole",
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
user_inputs = user_inputs or {}
|
user_inputs = user_inputs or {}
|
||||||
# Filter input variables from form configuration, handle required fields, default values, and option values
|
# Filter input variables from form configuration, handle required fields, default values, and option values
|
||||||
variables = app_config.variables
|
variables = app_config.variables
|
||||||
filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
|
user_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
|
||||||
filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
|
user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()}
|
||||||
return filtered_inputs
|
# Convert files in inputs to File
|
||||||
|
entity_dictionary = {item.variable: item for item in app_config.variables}
|
||||||
|
# Convert single file to File
|
||||||
|
files_inputs = {
|
||||||
|
k: file_factory.build_from_mapping(
|
||||||
|
mapping=v,
|
||||||
|
tenant_id=app_config.tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
role=role,
|
||||||
|
config=FileExtraConfig(
|
||||||
|
allowed_file_types=entity_dictionary[k].allowed_file_types,
|
||||||
|
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
|
||||||
|
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for k, v in user_inputs.items()
|
||||||
|
if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE
|
||||||
|
}
|
||||||
|
# Convert list of files to File
|
||||||
|
file_list_inputs = {
|
||||||
|
k: file_factory.build_from_mappings(
|
||||||
|
mappings=v,
|
||||||
|
tenant_id=app_config.tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
role=role,
|
||||||
|
config=FileExtraConfig(
|
||||||
|
allowed_file_types=entity_dictionary[k].allowed_file_types,
|
||||||
|
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
|
||||||
|
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for k, v in user_inputs.items()
|
||||||
|
if isinstance(v, list)
|
||||||
|
# Ensure skip List<File>
|
||||||
|
and all(isinstance(item, dict) for item in v)
|
||||||
|
and entity_dictionary[k].type == VariableEntityType.FILE_LIST
|
||||||
|
}
|
||||||
|
# Merge all inputs
|
||||||
|
user_inputs = {**user_inputs, **files_inputs, **file_list_inputs}
|
||||||
|
|
||||||
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
|
# Check if all files are converted to File
|
||||||
user_input_value = inputs.get(var.variable)
|
if any(filter(lambda v: isinstance(v, dict), user_inputs.values())):
|
||||||
if var.required and not user_input_value:
|
raise ValueError("Invalid input type")
|
||||||
raise ValueError(f"{var.variable} is required in input form")
|
if any(
|
||||||
if not var.required and not user_input_value:
|
filter(lambda v: isinstance(v, dict), filter(lambda item: isinstance(item, list), user_inputs.values()))
|
||||||
# TODO: should we return None here if the default value is None?
|
|
||||||
return var.default or ""
|
|
||||||
if (
|
|
||||||
var.type
|
|
||||||
in {
|
|
||||||
VariableEntityType.TEXT_INPUT,
|
|
||||||
VariableEntityType.SELECT,
|
|
||||||
VariableEntityType.PARAGRAPH,
|
|
||||||
}
|
|
||||||
and user_input_value
|
|
||||||
and not isinstance(user_input_value, str)
|
|
||||||
):
|
):
|
||||||
|
raise ValueError("Invalid input type")
|
||||||
|
|
||||||
|
return user_inputs
|
||||||
|
|
||||||
|
def _validate_input(self, *, inputs: Mapping[str, Any], var: "VariableEntity"):
|
||||||
|
user_input_value = inputs.get(var.variable)
|
||||||
|
if not user_input_value:
|
||||||
|
if var.required:
|
||||||
|
raise ValueError(f"{var.variable} is required in input form")
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if var.type in {
|
||||||
|
VariableEntityType.TEXT_INPUT,
|
||||||
|
VariableEntityType.SELECT,
|
||||||
|
VariableEntityType.PARAGRAPH,
|
||||||
|
} and not isinstance(user_input_value, str):
|
||||||
raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
|
raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
|
||||||
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
|
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
|
||||||
# may raise ValueError if user_input_value is not a valid number
|
# may raise ValueError if user_input_value is not a valid number
|
||||||
|
|
@ -41,12 +98,24 @@ class BaseAppGenerator:
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise ValueError(f"{var.variable} in input form must be a valid number")
|
raise ValueError(f"{var.variable} in input form must be a valid number")
|
||||||
if var.type == VariableEntityType.SELECT:
|
if var.type == VariableEntityType.SELECT:
|
||||||
options = var.options or []
|
options = var.options
|
||||||
if user_input_value not in options:
|
if user_input_value not in options:
|
||||||
raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
|
raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
|
||||||
elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}:
|
elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}:
|
||||||
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
|
if var.max_length and len(user_input_value) > var.max_length:
|
||||||
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
|
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
|
||||||
|
elif var.type == VariableEntityType.FILE:
|
||||||
|
if not isinstance(user_input_value, dict) and not isinstance(user_input_value, File):
|
||||||
|
raise ValueError(f"{var.variable} in input form must be a file")
|
||||||
|
elif var.type == VariableEntityType.FILE_LIST:
|
||||||
|
if not (
|
||||||
|
isinstance(user_input_value, list)
|
||||||
|
and (
|
||||||
|
all(isinstance(item, dict) for item in user_input_value)
|
||||||
|
or all(isinstance(item, File) for item in user_input_value)
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise ValueError(f"{var.variable} in input form must be a list of files")
|
||||||
|
|
||||||
return user_input_value
|
return user_input_value
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@ from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
|
||||||
from models.model import App, AppMode, Message, MessageAnnotation
|
from models.model import App, AppMode, Message, MessageAnnotation
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from core.file.file_obj import FileVar
|
from core.file.models import File
|
||||||
|
|
||||||
|
|
||||||
class AppRunner:
|
class AppRunner:
|
||||||
|
|
@ -37,7 +37,7 @@ class AppRunner:
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
prompt_template_entity: PromptTemplateEntity,
|
prompt_template_entity: PromptTemplateEntity,
|
||||||
inputs: dict[str, str],
|
inputs: dict[str, str],
|
||||||
files: list["FileVar"],
|
files: list["File"],
|
||||||
query: Optional[str] = None,
|
query: Optional[str] = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
|
|
@ -137,7 +137,7 @@ class AppRunner:
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
prompt_template_entity: PromptTemplateEntity,
|
prompt_template_entity: PromptTemplateEntity,
|
||||||
inputs: dict[str, str],
|
inputs: dict[str, str],
|
||||||
files: list["FileVar"],
|
files: list["File"],
|
||||||
query: Optional[str] = None,
|
query: Optional[str] = None,
|
||||||
context: Optional[str] = None,
|
context: Optional[str] = None,
|
||||||
memory: Optional[TokenBufferMemory] = None,
|
memory: Optional[TokenBufferMemory] = None,
|
||||||
|
|
|
||||||
|
|
@ -17,10 +17,11 @@ from core.app.apps.chat.generate_response_converter import ChatAppGenerateRespon
|
||||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
|
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
|
||||||
from core.file.message_file_parser import MessageFileParser
|
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
|
from enums import CreatedByRole
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from factories import file_factory
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.model import App, EndUser
|
from models.model import App, EndUser
|
||||||
|
|
||||||
|
|
@ -99,12 +100,19 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||||
# always enable retriever resource in debugger mode
|
# always enable retriever resource in debugger mode
|
||||||
override_model_config_dict["retriever_resource"] = {"enabled": True}
|
override_model_config_dict["retriever_resource"] = {"enabled": True}
|
||||||
|
|
||||||
|
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
||||||
|
|
||||||
# parse files
|
# parse files
|
||||||
files = args["files"] if args.get("files") else []
|
files = args["files"] if args.get("files") else []
|
||||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
|
||||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||||
if file_extra_config:
|
if file_extra_config:
|
||||||
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
|
file_objs = file_factory.build_from_mappings(
|
||||||
|
mappings=files,
|
||||||
|
tenant_id=app_model.tenant_id,
|
||||||
|
user_id=user.id,
|
||||||
|
role=role,
|
||||||
|
config=file_extra_config,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
file_objs = []
|
file_objs = []
|
||||||
|
|
||||||
|
|
@ -117,7 +125,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||||
)
|
)
|
||||||
|
|
||||||
# get tracing instance
|
# get tracing instance
|
||||||
trace_manager = TraceQueueManager(app_model.id)
|
trace_manager = TraceQueueManager(app_id=app_model.id)
|
||||||
|
|
||||||
# init application generate entity
|
# init application generate entity
|
||||||
application_generate_entity = ChatAppGenerateEntity(
|
application_generate_entity = ChatAppGenerateEntity(
|
||||||
|
|
@ -125,15 +133,17 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||||
app_config=app_config,
|
app_config=app_config,
|
||||||
model_conf=ModelConfigConverter.convert(app_config),
|
model_conf=ModelConfigConverter.convert(app_config),
|
||||||
conversation_id=conversation.id if conversation else None,
|
conversation_id=conversation.id if conversation else None,
|
||||||
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
|
inputs=conversation.inputs
|
||||||
|
if conversation
|
||||||
|
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
||||||
query=query,
|
query=query,
|
||||||
files=file_objs,
|
files=file_objs,
|
||||||
parent_message_id=args.get("parent_message_id"),
|
parent_message_id=args.get("parent_message_id"),
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
stream=stream,
|
|
||||||
invoke_from=invoke_from,
|
invoke_from=invoke_from,
|
||||||
extras=extras,
|
extras=extras,
|
||||||
trace_manager=trace_manager,
|
trace_manager=trace_manager,
|
||||||
|
stream=stream,
|
||||||
)
|
)
|
||||||
|
|
||||||
# init generate records
|
# init generate records
|
||||||
|
|
|
||||||
|
|
@ -17,12 +17,12 @@ from core.app.apps.completion.generate_response_converter import CompletionAppGe
|
||||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
|
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
|
||||||
from core.file.message_file_parser import MessageFileParser
|
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
|
from enums import CreatedByRole
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.account import Account
|
from factories import file_factory
|
||||||
from models.model import App, EndUser, Message
|
from models import Account, App, EndUser, Message
|
||||||
from services.errors.app import MoreLikeThisDisabledError
|
from services.errors.app import MoreLikeThisDisabledError
|
||||||
from services.errors.message import MessageNotExistsError
|
from services.errors.message import MessageNotExistsError
|
||||||
|
|
||||||
|
|
@ -88,12 +88,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||||
tenant_id=app_model.tenant_id, config=args.get("model_config")
|
tenant_id=app_model.tenant_id, config=args.get("model_config")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
||||||
|
|
||||||
# parse files
|
# parse files
|
||||||
files = args["files"] if args.get("files") else []
|
files = args["files"] if args.get("files") else []
|
||||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
|
||||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||||
if file_extra_config:
|
if file_extra_config:
|
||||||
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
|
file_objs = file_factory.build_from_mappings(
|
||||||
|
mappings=files,
|
||||||
|
tenant_id=app_model.tenant_id,
|
||||||
|
user_id=user.id,
|
||||||
|
role=role,
|
||||||
|
config=file_extra_config,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
file_objs = []
|
file_objs = []
|
||||||
|
|
||||||
|
|
@ -103,6 +110,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||||
)
|
)
|
||||||
|
|
||||||
# get tracing instance
|
# get tracing instance
|
||||||
|
user_id = user.id if isinstance(user, Account) else user.session_id
|
||||||
trace_manager = TraceQueueManager(app_model.id)
|
trace_manager = TraceQueueManager(app_model.id)
|
||||||
|
|
||||||
# init application generate entity
|
# init application generate entity
|
||||||
|
|
@ -110,7 +118,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||||
task_id=str(uuid.uuid4()),
|
task_id=str(uuid.uuid4()),
|
||||||
app_config=app_config,
|
app_config=app_config,
|
||||||
model_conf=ModelConfigConverter.convert(app_config),
|
model_conf=ModelConfigConverter.convert(app_config),
|
||||||
inputs=self._get_cleaned_inputs(inputs, app_config),
|
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
||||||
query=query,
|
query=query,
|
||||||
files=file_objs,
|
files=file_objs,
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
|
|
@ -251,10 +259,16 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||||
override_model_config_dict["model"] = model_dict
|
override_model_config_dict["model"] = model_dict
|
||||||
|
|
||||||
# parse files
|
# parse files
|
||||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
||||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict)
|
||||||
if file_extra_config:
|
if file_extra_config:
|
||||||
file_objs = message_file_parser.validate_and_transform_files_arg(message.files, file_extra_config, user)
|
file_objs = file_factory.build_from_mappings(
|
||||||
|
mappings=message.files,
|
||||||
|
tenant_id=app_model.tenant_id,
|
||||||
|
user_id=user.id,
|
||||||
|
role=role,
|
||||||
|
config=file_extra_config,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
file_objs = []
|
file_objs = []
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ from core.app.entities.task_entities import (
|
||||||
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
|
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
|
||||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.account import Account
|
from models import Account
|
||||||
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
|
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
|
||||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||||
from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
|
from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
|
||||||
|
|
@ -238,7 +238,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||||
type=file.type.value,
|
type=file.type.value,
|
||||||
transfer_method=file.transfer_method.value,
|
transfer_method=file.transfer_method.value,
|
||||||
belongs_to="user",
|
belongs_to="user",
|
||||||
url=file.url,
|
url=file.remote_url,
|
||||||
upload_file_id=file.related_id,
|
upload_file_id=file.related_id,
|
||||||
created_by_role=("account" if account_id else "end_user"),
|
created_by_role=("account" if account_id else "end_user"),
|
||||||
created_by=account_id or end_user_id,
|
created_by=account_id or end_user_id,
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from typing import Any, Literal, Optional, Union, overload
|
from typing import Any, Literal, Optional, Union, overload
|
||||||
|
|
||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
|
|
@ -20,13 +20,12 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
|
||||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||||
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
||||||
from core.file.message_file_parser import MessageFileParser
|
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
|
from enums import CreatedByRole
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.account import Account
|
from factories import file_factory
|
||||||
from models.model import App, EndUser
|
from models import Account, App, EndUser, Workflow
|
||||||
from models.workflow import Workflow
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -63,48 +62,45 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||||
app_model: App,
|
app_model: App,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
user: Union[Account, EndUser],
|
user: Union[Account, EndUser],
|
||||||
args: dict,
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
call_depth: int = 0,
|
call_depth: int = 0,
|
||||||
workflow_thread_pool_id: Optional[str] = None,
|
workflow_thread_pool_id: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||||
Generate App response.
|
|
||||||
|
|
||||||
:param app_model: App
|
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
||||||
:param workflow: Workflow
|
|
||||||
:param user: account or end user
|
|
||||||
:param args: request args
|
|
||||||
:param invoke_from: invoke from source
|
|
||||||
:param stream: is stream
|
|
||||||
:param call_depth: call depth
|
|
||||||
:param workflow_thread_pool_id: workflow thread pool id
|
|
||||||
"""
|
|
||||||
inputs = args["inputs"]
|
|
||||||
|
|
||||||
# parse files
|
# parse files
|
||||||
files = args["files"] if args.get("files") else []
|
|
||||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
|
||||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||||
if file_extra_config:
|
system_files = file_factory.build_from_mappings(
|
||||||
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
|
mappings=files,
|
||||||
else:
|
tenant_id=app_model.tenant_id,
|
||||||
file_objs = []
|
user_id=user.id,
|
||||||
|
role=role,
|
||||||
|
config=file_extra_config,
|
||||||
|
)
|
||||||
|
|
||||||
# convert to app config
|
# convert to app config
|
||||||
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
|
app_config = WorkflowAppConfigManager.get_app_config(
|
||||||
|
app_model=app_model,
|
||||||
|
workflow=workflow,
|
||||||
|
)
|
||||||
|
|
||||||
# get tracing instance
|
# get tracing instance
|
||||||
user_id = user.id if isinstance(user, Account) else user.session_id
|
trace_manager = TraceQueueManager(
|
||||||
trace_manager = TraceQueueManager(app_model.id, user_id)
|
app_id=app_model.id,
|
||||||
|
user_id=user.id if isinstance(user, Account) else user.session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs: Mapping[str, Any] = args["inputs"]
|
||||||
# init application generate entity
|
# init application generate entity
|
||||||
application_generate_entity = WorkflowAppGenerateEntity(
|
application_generate_entity = WorkflowAppGenerateEntity(
|
||||||
task_id=str(uuid.uuid4()),
|
task_id=str(uuid.uuid4()),
|
||||||
app_config=app_config,
|
app_config=app_config,
|
||||||
inputs=self._get_cleaned_inputs(inputs, app_config),
|
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
||||||
files=file_objs,
|
files=system_files,
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
invoke_from=invoke_from,
|
invoke_from=invoke_from,
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,19 @@
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from typing import Optional, cast
|
from typing import Optional, cast
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
|
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
|
||||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||||
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
|
|
||||||
from core.app.entities.app_invoke_entities import (
|
from core.app.entities.app_invoke_entities import (
|
||||||
InvokeFrom,
|
InvokeFrom,
|
||||||
WorkflowAppGenerateEntity,
|
WorkflowAppGenerateEntity,
|
||||||
)
|
)
|
||||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
||||||
from core.workflow.entities.node_entities import UserFrom
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.workflow_entry import WorkflowEntry
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
|
from enums import UserFrom
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import App, EndUser
|
from models.model import App, EndUser
|
||||||
from models.workflow import WorkflowType
|
from models.workflow import WorkflowType
|
||||||
|
|
@ -71,7 +70,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||||
db.session.close()
|
db.session.close()
|
||||||
|
|
||||||
workflow_callbacks: list[WorkflowCallback] = []
|
workflow_callbacks: list[WorkflowCallback] = []
|
||||||
if bool(os.environ.get("DEBUG", "False").lower() == "true"):
|
if dify_config.DEBUG:
|
||||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||||
|
|
||||||
# if only single iteration run is requested
|
# if only single iteration run is requested
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
|
@ -328,9 +327,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||||
start_at=graph_runtime_state.start_at,
|
start_at=graph_runtime_state.start_at,
|
||||||
total_tokens=graph_runtime_state.total_tokens,
|
total_tokens=graph_runtime_state.total_tokens,
|
||||||
total_steps=graph_runtime_state.node_run_steps,
|
total_steps=graph_runtime_state.node_run_steps,
|
||||||
outputs=json.dumps(event.outputs)
|
outputs=event.outputs,
|
||||||
if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs
|
|
||||||
else None,
|
|
||||||
conversation_id=None,
|
conversation_id=None,
|
||||||
trace_manager=trace_manager,
|
trace_manager=trace_manager,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,6 @@ from core.app.entities.queue_entities import (
|
||||||
QueueWorkflowStartedEvent,
|
QueueWorkflowStartedEvent,
|
||||||
QueueWorkflowSucceededEvent,
|
QueueWorkflowSucceededEvent,
|
||||||
)
|
)
|
||||||
from core.workflow.entities.node_entities import NodeType
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.graph_engine.entities.event import (
|
from core.workflow.graph_engine.entities.event import (
|
||||||
GraphEngineEvent,
|
GraphEngineEvent,
|
||||||
|
|
@ -45,6 +44,7 @@ from core.workflow.nodes.base_node import BaseNode
|
||||||
from core.workflow.nodes.iteration.entities import IterationNodeData
|
from core.workflow.nodes.iteration.entities import IterationNodeData
|
||||||
from core.workflow.nodes.node_mapping import node_classes
|
from core.workflow.nodes.node_mapping import node_classes
|
||||||
from core.workflow.workflow_entry import WorkflowEntry
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
|
from enums import NodeType
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import App
|
from models.model import App
|
||||||
from models.workflow import Workflow
|
from models.workflow import Workflow
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping, Sequence
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
|
@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
|
from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
|
||||||
from core.entities.provider_configuration import ProviderModelBundle
|
from core.entities.provider_configuration import ProviderModelBundle
|
||||||
from core.file.file_obj import FileVar
|
from core.file.models import File
|
||||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
|
|
||||||
|
|
@ -22,7 +22,7 @@ class InvokeFrom(Enum):
|
||||||
DEBUGGER = "debugger"
|
DEBUGGER = "debugger"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_of(cls, value: str) -> "InvokeFrom":
|
def value_of(cls, value: str):
|
||||||
"""
|
"""
|
||||||
Get value of given mode.
|
Get value of given mode.
|
||||||
|
|
||||||
|
|
@ -81,7 +81,7 @@ class AppGenerateEntity(BaseModel):
|
||||||
app_config: AppConfig
|
app_config: AppConfig
|
||||||
|
|
||||||
inputs: Mapping[str, Any]
|
inputs: Mapping[str, Any]
|
||||||
files: list[FileVar] = []
|
files: Sequence[File]
|
||||||
user_id: str
|
user_id: str
|
||||||
|
|
||||||
# extras
|
# extras
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,9 @@ from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
|
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
|
from enums import NodeType
|
||||||
|
|
||||||
|
|
||||||
class QueueEvent(str, Enum):
|
class QueueEvent(str, Enum):
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
from collections.abc import Mapping
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
|
@ -119,6 +120,7 @@ class MessageEndStreamResponse(StreamResponse):
|
||||||
event: StreamEvent = StreamEvent.MESSAGE_END
|
event: StreamEvent = StreamEvent.MESSAGE_END
|
||||||
id: str
|
id: str
|
||||||
metadata: dict = {}
|
metadata: dict = {}
|
||||||
|
files: Optional[list[Mapping[str, Any]]] = None
|
||||||
|
|
||||||
|
|
||||||
class MessageFileStreamResponse(StreamResponse):
|
class MessageFileStreamResponse(StreamResponse):
|
||||||
|
|
|
||||||
|
|
@ -1,18 +0,0 @@
|
||||||
import re
|
|
||||||
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
|
||||||
|
|
||||||
from . import SegmentGroup, factory
|
|
||||||
|
|
||||||
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
|
|
||||||
|
|
||||||
|
|
||||||
def convert_template(*, template: str, variable_pool: VariablePool):
|
|
||||||
parts = re.split(VARIABLE_PATTERN, template)
|
|
||||||
segments = []
|
|
||||||
for part in filter(lambda x: x, parts):
|
|
||||||
if "." in part and (value := variable_pool.get(part.split("."))):
|
|
||||||
segments.append(value)
|
|
||||||
else:
|
|
||||||
segments.append(factory.build_segment(part))
|
|
||||||
return SegmentGroup(value=segments)
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
from collections.abc import Mapping
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Optional, Union, cast
|
from typing import Any, Optional, Union, cast
|
||||||
|
|
||||||
|
|
@ -27,15 +28,15 @@ from core.app.entities.task_entities import (
|
||||||
WorkflowStartStreamResponse,
|
WorkflowStartStreamResponse,
|
||||||
WorkflowTaskState,
|
WorkflowTaskState,
|
||||||
)
|
)
|
||||||
from core.file.file_obj import FileVar
|
from core.file import FILE_MODEL_IDENTITY, File
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.ops.entities.trace_entity import TraceTaskName
|
from core.ops.entities.trace_entity import TraceTaskName
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||||
from core.tools.tool_manager import ToolManager
|
from core.tools.tool_manager import ToolManager
|
||||||
from core.workflow.entities.node_entities import NodeType
|
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||||
from core.workflow.workflow_entry import WorkflowEntry
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
|
from enums import NodeType, WorkflowRunTriggeredFrom
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.model import EndUser
|
from models.model import EndUser
|
||||||
|
|
@ -47,7 +48,6 @@ from models.workflow import (
|
||||||
WorkflowNodeExecutionTriggeredFrom,
|
WorkflowNodeExecutionTriggeredFrom,
|
||||||
WorkflowRun,
|
WorkflowRun,
|
||||||
WorkflowRunStatus,
|
WorkflowRunStatus,
|
||||||
WorkflowRunTriggeredFrom,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -113,7 +113,7 @@ class WorkflowCycleManage:
|
||||||
start_at: float,
|
start_at: float,
|
||||||
total_tokens: int,
|
total_tokens: int,
|
||||||
total_steps: int,
|
total_steps: int,
|
||||||
outputs: Optional[str] = None,
|
outputs: Mapping[str, Any] | None = None,
|
||||||
conversation_id: Optional[str] = None,
|
conversation_id: Optional[str] = None,
|
||||||
trace_manager: Optional[TraceQueueManager] = None,
|
trace_manager: Optional[TraceQueueManager] = None,
|
||||||
) -> WorkflowRun:
|
) -> WorkflowRun:
|
||||||
|
|
@ -129,8 +129,10 @@ class WorkflowCycleManage:
|
||||||
"""
|
"""
|
||||||
workflow_run = self._refetch_workflow_run(workflow_run.id)
|
workflow_run = self._refetch_workflow_run(workflow_run.id)
|
||||||
|
|
||||||
|
outputs = WorkflowEntry.handle_special_values(outputs)
|
||||||
|
|
||||||
workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
|
workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
|
||||||
workflow_run.outputs = outputs
|
workflow_run.outputs = json.dumps(outputs) if outputs else None
|
||||||
workflow_run.elapsed_time = time.perf_counter() - start_at
|
workflow_run.elapsed_time = time.perf_counter() - start_at
|
||||||
workflow_run.total_tokens = total_tokens
|
workflow_run.total_tokens = total_tokens
|
||||||
workflow_run.total_steps = total_steps
|
workflow_run.total_steps = total_steps
|
||||||
|
|
@ -260,10 +262,11 @@ class WorkflowCycleManage:
|
||||||
|
|
||||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||||
|
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||||
|
|
||||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||||
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
|
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
|
||||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||||
workflow_node_execution.execution_metadata = (
|
workflow_node_execution.execution_metadata = (
|
||||||
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
|
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
|
||||||
|
|
@ -287,12 +290,13 @@ class WorkflowCycleManage:
|
||||||
|
|
||||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||||
|
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||||
|
|
||||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||||
workflow_node_execution.error = event.error
|
workflow_node_execution.error = event.error
|
||||||
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||||
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
|
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
|
||||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||||
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
|
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
|
||||||
|
|
||||||
|
|
@ -609,11 +613,11 @@ class WorkflowCycleManage:
|
||||||
if not outputs_dict:
|
if not outputs_dict:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
files = []
|
files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
|
||||||
for output_var, output_value in outputs_dict.items():
|
# Remove None
|
||||||
file_vars = self._fetch_files_from_variable_value(output_value)
|
files = [file for file in files if file]
|
||||||
if file_vars:
|
# Flatten list
|
||||||
files.extend(file_vars)
|
files = [file for sublist in files for file in sublist]
|
||||||
|
|
||||||
return files
|
return files
|
||||||
|
|
||||||
|
|
@ -639,7 +643,7 @@ class WorkflowCycleManage:
|
||||||
|
|
||||||
return files
|
return files
|
||||||
|
|
||||||
def _get_file_var_from_value(self, value: Union[dict, list]) -> Optional[dict]:
|
def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, str | int | None] | None:
|
||||||
"""
|
"""
|
||||||
Get file var from value
|
Get file var from value
|
||||||
:param value: variable value
|
:param value: variable value
|
||||||
|
|
@ -648,14 +652,11 @@ class WorkflowCycleManage:
|
||||||
if not value:
|
if not value:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict) and value.get("model_identity") == FILE_MODEL_IDENTITY:
|
||||||
if "__variant" in value and value["__variant"] == FileVar.__name__:
|
return value
|
||||||
return value
|
elif isinstance(value, File):
|
||||||
elif isinstance(value, FileVar):
|
|
||||||
return value.to_dict()
|
return value.to_dict()
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
|
def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
|
||||||
"""
|
"""
|
||||||
Refetch workflow run
|
Refetch workflow run
|
||||||
|
|
|
||||||
|
|
@ -1,29 +0,0 @@
|
||||||
import enum
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class PromptMessageFileType(enum.Enum):
|
|
||||||
IMAGE = "image"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def value_of(value):
|
|
||||||
for member in PromptMessageFileType:
|
|
||||||
if member.value == value:
|
|
||||||
return member
|
|
||||||
raise ValueError(f"No matching enum found for value '{value}'")
|
|
||||||
|
|
||||||
|
|
||||||
class PromptMessageFile(BaseModel):
|
|
||||||
type: PromptMessageFileType
|
|
||||||
data: Any = None
|
|
||||||
|
|
||||||
|
|
||||||
class ImagePromptMessageFile(PromptMessageFile):
|
|
||||||
class DETAIL(enum.Enum):
|
|
||||||
LOW = "low"
|
|
||||||
HIGH = "high"
|
|
||||||
|
|
||||||
type: PromptMessageFileType = PromptMessageFileType.IMAGE
|
|
||||||
detail: DETAIL = DETAIL.LOW
|
|
||||||
|
|
@ -0,0 +1,19 @@
|
||||||
|
from .constants import FILE_MODEL_IDENTITY
|
||||||
|
from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType
|
||||||
|
from .models import (
|
||||||
|
File,
|
||||||
|
FileExtraConfig,
|
||||||
|
ImageConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"FileType",
|
||||||
|
"FileExtraConfig",
|
||||||
|
"FileTransferMethod",
|
||||||
|
"FileBelongsTo",
|
||||||
|
"File",
|
||||||
|
"ImageConfig",
|
||||||
|
"FileAttribute",
|
||||||
|
"ArrayFileAttribute",
|
||||||
|
"FILE_MODEL_IDENTITY",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
FILE_MODEL_IDENTITY = "__dify__file__"
|
||||||
|
|
@ -0,0 +1,55 @@
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class FileType(str, Enum):
|
||||||
|
IMAGE = "image"
|
||||||
|
DOCUMENT = "document"
|
||||||
|
AUDIO = "audio"
|
||||||
|
VIDEO = "video"
|
||||||
|
CUSTOM = "custom"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def value_of(value):
|
||||||
|
for member in FileType:
|
||||||
|
if member.value == value:
|
||||||
|
return member
|
||||||
|
raise ValueError(f"No matching enum found for value '{value}'")
|
||||||
|
|
||||||
|
|
||||||
|
class FileTransferMethod(str, Enum):
|
||||||
|
REMOTE_URL = "remote_url"
|
||||||
|
LOCAL_FILE = "local_file"
|
||||||
|
TOOL_FILE = "tool_file"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def value_of(value):
|
||||||
|
for member in FileTransferMethod:
|
||||||
|
if member.value == value:
|
||||||
|
return member
|
||||||
|
raise ValueError(f"No matching enum found for value '{value}'")
|
||||||
|
|
||||||
|
|
||||||
|
class FileBelongsTo(str, Enum):
|
||||||
|
USER = "user"
|
||||||
|
ASSISTANT = "assistant"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def value_of(value):
|
||||||
|
for member in FileBelongsTo:
|
||||||
|
if member.value == value:
|
||||||
|
return member
|
||||||
|
raise ValueError(f"No matching enum found for value '{value}'")
|
||||||
|
|
||||||
|
|
||||||
|
class FileAttribute(str, Enum):
|
||||||
|
TYPE = "type"
|
||||||
|
SIZE = "size"
|
||||||
|
NAME = "name"
|
||||||
|
MIME_TYPE = "mime_type"
|
||||||
|
TRANSFER_METHOD = "transfer_method"
|
||||||
|
URL = "url"
|
||||||
|
EXTENSION = "extension"
|
||||||
|
|
||||||
|
|
||||||
|
class ArrayFileAttribute(str, Enum):
|
||||||
|
LENGTH = "length"
|
||||||
|
|
@ -0,0 +1,136 @@
|
||||||
|
import base64
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from extensions.ext_storage import storage
|
||||||
|
from models import UploadFile
|
||||||
|
|
||||||
|
from . import helpers
|
||||||
|
from .enums import FileAttribute
|
||||||
|
from .models import File, FileTransferMethod, FileType
|
||||||
|
from .tool_file_parser import ToolFileParser
|
||||||
|
|
||||||
|
|
||||||
|
def get_attr(*, file: "File", attr: "FileAttribute"):
|
||||||
|
match attr:
|
||||||
|
case FileAttribute.TYPE:
|
||||||
|
return file.type.value
|
||||||
|
case FileAttribute.SIZE:
|
||||||
|
return file.size
|
||||||
|
case FileAttribute.NAME:
|
||||||
|
return file.filename
|
||||||
|
case FileAttribute.MIME_TYPE:
|
||||||
|
return file.mime_type
|
||||||
|
case FileAttribute.TRANSFER_METHOD:
|
||||||
|
return file.transfer_method.value
|
||||||
|
case FileAttribute.URL:
|
||||||
|
return file.remote_url
|
||||||
|
case FileAttribute.EXTENSION:
|
||||||
|
return file.extension
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Invalid file attribute: {attr}")
|
||||||
|
|
||||||
|
|
||||||
|
def to_prompt_message_content(file: "File", /):
|
||||||
|
"""
|
||||||
|
Convert a File object to an ImagePromptMessageContent object.
|
||||||
|
|
||||||
|
This function takes a File object and converts it to an ImagePromptMessageContent
|
||||||
|
object, which can be used as a prompt for image-based AI models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file (File): The File object to convert. Must be of type FileType.IMAGE.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ImagePromptMessageContent: An object containing the image data and detail level.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the file is not an image or if the file data is missing.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The detail level of the image prompt is determined by the file's extra_config.
|
||||||
|
If not specified, it defaults to ImagePromptMessageContent.DETAIL.LOW.
|
||||||
|
"""
|
||||||
|
if file.type != FileType.IMAGE:
|
||||||
|
raise ValueError("Only image file can convert to prompt message content")
|
||||||
|
|
||||||
|
url_or_b64_data = _get_url_or_b64_data(file=file)
|
||||||
|
if url_or_b64_data is None:
|
||||||
|
raise ValueError("Missing file data")
|
||||||
|
|
||||||
|
# decide the detail of image prompt message content
|
||||||
|
if file._extra_config and file._extra_config.image_config and file._extra_config.image_config.detail:
|
||||||
|
detail = file._extra_config.image_config.detail
|
||||||
|
else:
|
||||||
|
detail = ImagePromptMessageContent.DETAIL.LOW
|
||||||
|
|
||||||
|
return ImagePromptMessageContent(data=url_or_b64_data, detail=detail)
|
||||||
|
|
||||||
|
|
||||||
|
def download(*, upload_file_id: str, tenant_id: str):
|
||||||
|
upload_file = (
|
||||||
|
db.session.query(UploadFile).filter(UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id).first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not upload_file:
|
||||||
|
raise ValueError("upload file not found")
|
||||||
|
|
||||||
|
return _download(upload_file.key)
|
||||||
|
|
||||||
|
|
||||||
|
def _download(path: str, /):
|
||||||
|
"""
|
||||||
|
Download and return the contents of a file as bytes.
|
||||||
|
|
||||||
|
This function loads the file from storage and ensures it's in bytes format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path (str): The path to the file in storage.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bytes: The contents of the file as a bytes object.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the loaded file is not a bytes object.
|
||||||
|
"""
|
||||||
|
data = storage.load(path, stream=False)
|
||||||
|
if not isinstance(data, bytes):
|
||||||
|
raise ValueError(f"file {path} is not a bytes object")
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def _get_base64(*, upload_file_id: str, tenant_id: str) -> str | None:
|
||||||
|
upload_file = (
|
||||||
|
db.session.query(UploadFile).filter(UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id).first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not upload_file:
|
||||||
|
return None
|
||||||
|
|
||||||
|
data = _download(upload_file.key)
|
||||||
|
if data is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
encoded_string = base64.b64encode(data).decode("utf-8")
|
||||||
|
return f"data:{upload_file.mime_type};base64,{encoded_string}"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_url_or_b64_data(file: "File"):
|
||||||
|
if file.type == FileType.IMAGE:
|
||||||
|
if file.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||||
|
return file.remote_url
|
||||||
|
elif file.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||||
|
if file.related_id is None:
|
||||||
|
raise ValueError("Missing file related_id")
|
||||||
|
|
||||||
|
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
|
||||||
|
return helpers.get_signed_image_url(upload_file_id=file.related_id)
|
||||||
|
return _get_base64(upload_file_id=file.related_id, tenant_id=file.tenant_id)
|
||||||
|
elif file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||||
|
# add sign url
|
||||||
|
if file.related_id is None or file.extension is None:
|
||||||
|
raise ValueError("Missing file related_id or extension")
|
||||||
|
return ToolFileParser.get_tool_file_manager().sign_file(
|
||||||
|
tool_file_id=file.related_id, extension=file.extension
|
||||||
|
)
|
||||||
|
|
@ -1,145 +0,0 @@
|
||||||
import enum
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from core.file.tool_file_parser import ToolFileParser
|
|
||||||
from core.file.upload_file_parser import UploadFileParser
|
|
||||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
|
||||||
from extensions.ext_database import db
|
|
||||||
|
|
||||||
|
|
||||||
class FileExtraConfig(BaseModel):
|
|
||||||
"""
|
|
||||||
File Upload Entity.
|
|
||||||
"""
|
|
||||||
|
|
||||||
image_config: Optional[dict[str, Any]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class FileType(enum.Enum):
|
|
||||||
IMAGE = "image"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def value_of(value):
|
|
||||||
for member in FileType:
|
|
||||||
if member.value == value:
|
|
||||||
return member
|
|
||||||
raise ValueError(f"No matching enum found for value '{value}'")
|
|
||||||
|
|
||||||
|
|
||||||
class FileTransferMethod(enum.Enum):
|
|
||||||
REMOTE_URL = "remote_url"
|
|
||||||
LOCAL_FILE = "local_file"
|
|
||||||
TOOL_FILE = "tool_file"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def value_of(value):
|
|
||||||
for member in FileTransferMethod:
|
|
||||||
if member.value == value:
|
|
||||||
return member
|
|
||||||
raise ValueError(f"No matching enum found for value '{value}'")
|
|
||||||
|
|
||||||
|
|
||||||
class FileBelongsTo(enum.Enum):
|
|
||||||
USER = "user"
|
|
||||||
ASSISTANT = "assistant"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def value_of(value):
|
|
||||||
for member in FileBelongsTo:
|
|
||||||
if member.value == value:
|
|
||||||
return member
|
|
||||||
raise ValueError(f"No matching enum found for value '{value}'")
|
|
||||||
|
|
||||||
|
|
||||||
class FileVar(BaseModel):
|
|
||||||
id: Optional[str] = None # message file id
|
|
||||||
tenant_id: str
|
|
||||||
type: FileType
|
|
||||||
transfer_method: FileTransferMethod
|
|
||||||
url: Optional[str] = None # remote url
|
|
||||||
related_id: Optional[str] = None
|
|
||||||
extra_config: Optional[FileExtraConfig] = None
|
|
||||||
filename: Optional[str] = None
|
|
||||||
extension: Optional[str] = None
|
|
||||||
mime_type: Optional[str] = None
|
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
|
||||||
return {
|
|
||||||
"__variant": self.__class__.__name__,
|
|
||||||
"tenant_id": self.tenant_id,
|
|
||||||
"type": self.type.value,
|
|
||||||
"transfer_method": self.transfer_method.value,
|
|
||||||
"url": self.preview_url,
|
|
||||||
"remote_url": self.url,
|
|
||||||
"related_id": self.related_id,
|
|
||||||
"filename": self.filename,
|
|
||||||
"extension": self.extension,
|
|
||||||
"mime_type": self.mime_type,
|
|
||||||
}
|
|
||||||
|
|
||||||
def to_markdown(self) -> str:
|
|
||||||
"""
|
|
||||||
Convert file to markdown
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
preview_url = self.preview_url
|
|
||||||
if self.type == FileType.IMAGE:
|
|
||||||
text = f''
|
|
||||||
else:
|
|
||||||
text = f"[{self.filename or preview_url}]({preview_url})"
|
|
||||||
|
|
||||||
return text
|
|
||||||
|
|
||||||
@property
|
|
||||||
def data(self) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
Get image data, file signed url or base64 data
|
|
||||||
depending on config MULTIMODAL_SEND_IMAGE_FORMAT
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
return self._get_data()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def preview_url(self) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
Get signed preview url
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
return self._get_data(force_url=True)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def prompt_message_content(self) -> ImagePromptMessageContent:
|
|
||||||
if self.type == FileType.IMAGE:
|
|
||||||
image_config = self.extra_config.image_config
|
|
||||||
|
|
||||||
return ImagePromptMessageContent(
|
|
||||||
data=self.data,
|
|
||||||
detail=ImagePromptMessageContent.DETAIL.HIGH
|
|
||||||
if image_config.get("detail") == "high"
|
|
||||||
else ImagePromptMessageContent.DETAIL.LOW,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_data(self, force_url: bool = False) -> Optional[str]:
|
|
||||||
from models.model import UploadFile
|
|
||||||
|
|
||||||
if self.type == FileType.IMAGE:
|
|
||||||
if self.transfer_method == FileTransferMethod.REMOTE_URL:
|
|
||||||
return self.url
|
|
||||||
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
|
|
||||||
upload_file = (
|
|
||||||
db.session.query(UploadFile)
|
|
||||||
.filter(UploadFile.id == self.related_id, UploadFile.tenant_id == self.tenant_id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
return UploadFileParser.get_image_data(upload_file=upload_file, force_url=force_url)
|
|
||||||
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
|
|
||||||
extension = self.extension
|
|
||||||
# add sign url
|
|
||||||
return ToolFileParser.get_tool_file_manager().sign_file(
|
|
||||||
tool_file_id=self.related_id, extension=extension
|
|
||||||
)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
@ -0,0 +1,61 @@
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
|
||||||
|
|
||||||
|
def get_signed_image_url(upload_file_id: str) -> str:
|
||||||
|
url = f"{dify_config.FILES_URL}/files/{upload_file_id}/image-preview"
|
||||||
|
|
||||||
|
timestamp = str(int(time.time()))
|
||||||
|
nonce = os.urandom(16).hex()
|
||||||
|
key = dify_config.SECRET_KEY.encode()
|
||||||
|
msg = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||||
|
sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
|
||||||
|
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||||
|
|
||||||
|
return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_signed_file_url(upload_file_id: str) -> str:
|
||||||
|
url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview"
|
||||||
|
|
||||||
|
timestamp = str(int(time.time()))
|
||||||
|
nonce = os.urandom(16).hex()
|
||||||
|
key = dify_config.SECRET_KEY.encode()
|
||||||
|
msg = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||||
|
sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
|
||||||
|
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||||
|
|
||||||
|
return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||||
|
|
||||||
|
|
||||||
|
def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||||
|
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||||
|
secret_key = dify_config.SECRET_KEY.encode()
|
||||||
|
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||||
|
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||||
|
|
||||||
|
# verify signature
|
||||||
|
if sign != recalculated_encoded_sign:
|
||||||
|
return False
|
||||||
|
|
||||||
|
current_time = int(time.time())
|
||||||
|
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||||
|
|
||||||
|
|
||||||
|
def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||||
|
data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||||
|
secret_key = dify_config.SECRET_KEY.encode()
|
||||||
|
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||||
|
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||||
|
|
||||||
|
# verify signature
|
||||||
|
if sign != recalculated_encoded_sign:
|
||||||
|
return False
|
||||||
|
|
||||||
|
current_time = int(time.time())
|
||||||
|
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||||
|
|
@ -1,225 +0,0 @@
|
||||||
import re
|
|
||||||
from collections.abc import Mapping, Sequence
|
|
||||||
from typing import Any, Union
|
|
||||||
from urllib.parse import parse_qs, urlparse
|
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
from core.file.file_obj import FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType, FileVar
|
|
||||||
from extensions.ext_database import db
|
|
||||||
from models.account import Account
|
|
||||||
from models.model import EndUser, MessageFile, UploadFile
|
|
||||||
from services.file_service import IMAGE_EXTENSIONS
|
|
||||||
|
|
||||||
|
|
||||||
class MessageFileParser:
|
|
||||||
def __init__(self, tenant_id: str, app_id: str) -> None:
|
|
||||||
self.tenant_id = tenant_id
|
|
||||||
self.app_id = app_id
|
|
||||||
|
|
||||||
def validate_and_transform_files_arg(
|
|
||||||
self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, user: Union[Account, EndUser]
|
|
||||||
) -> list[FileVar]:
|
|
||||||
"""
|
|
||||||
validate and transform files arg
|
|
||||||
|
|
||||||
:param files:
|
|
||||||
:param file_extra_config:
|
|
||||||
:param user:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
for file in files:
|
|
||||||
if not isinstance(file, dict):
|
|
||||||
raise ValueError("Invalid file format, must be dict")
|
|
||||||
if not file.get("type"):
|
|
||||||
raise ValueError("Missing file type")
|
|
||||||
FileType.value_of(file.get("type"))
|
|
||||||
if not file.get("transfer_method"):
|
|
||||||
raise ValueError("Missing file transfer method")
|
|
||||||
FileTransferMethod.value_of(file.get("transfer_method"))
|
|
||||||
if file.get("transfer_method") == FileTransferMethod.REMOTE_URL.value:
|
|
||||||
if not file.get("url"):
|
|
||||||
raise ValueError("Missing file url")
|
|
||||||
if not file.get("url").startswith("http"):
|
|
||||||
raise ValueError("Invalid file url")
|
|
||||||
if file.get("transfer_method") == FileTransferMethod.LOCAL_FILE.value and not file.get("upload_file_id"):
|
|
||||||
raise ValueError("Missing file upload_file_id")
|
|
||||||
if file.get("transform_method") == FileTransferMethod.TOOL_FILE.value and not file.get("tool_file_id"):
|
|
||||||
raise ValueError("Missing file tool_file_id")
|
|
||||||
|
|
||||||
# transform files to file objs
|
|
||||||
type_file_objs = self._to_file_objs(files, file_extra_config)
|
|
||||||
|
|
||||||
# validate files
|
|
||||||
new_files = []
|
|
||||||
for file_type, file_objs in type_file_objs.items():
|
|
||||||
if file_type == FileType.IMAGE:
|
|
||||||
# parse and validate files
|
|
||||||
image_config = file_extra_config.image_config
|
|
||||||
|
|
||||||
# check if image file feature is enabled
|
|
||||||
if not image_config:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Validate number of files
|
|
||||||
if len(files) > image_config["number_limits"]:
|
|
||||||
raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}")
|
|
||||||
|
|
||||||
for file_obj in file_objs:
|
|
||||||
# Validate transfer method
|
|
||||||
if file_obj.transfer_method.value not in image_config["transfer_methods"]:
|
|
||||||
raise ValueError(f"Invalid transfer method: {file_obj.transfer_method.value}")
|
|
||||||
|
|
||||||
# Validate file type
|
|
||||||
if file_obj.type != FileType.IMAGE:
|
|
||||||
raise ValueError(f"Invalid file type: {file_obj.type}")
|
|
||||||
|
|
||||||
if file_obj.transfer_method == FileTransferMethod.REMOTE_URL:
|
|
||||||
# check remote url valid and is image
|
|
||||||
result, error = self._check_image_remote_url(file_obj.url)
|
|
||||||
if result is False:
|
|
||||||
raise ValueError(error)
|
|
||||||
elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE:
|
|
||||||
# get upload file from upload_file_id
|
|
||||||
upload_file = (
|
|
||||||
db.session.query(UploadFile)
|
|
||||||
.filter(
|
|
||||||
UploadFile.id == file_obj.related_id,
|
|
||||||
UploadFile.tenant_id == self.tenant_id,
|
|
||||||
UploadFile.created_by == user.id,
|
|
||||||
UploadFile.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
|
|
||||||
UploadFile.extension.in_(IMAGE_EXTENSIONS),
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
# check upload file is belong to tenant and user
|
|
||||||
if not upload_file:
|
|
||||||
raise ValueError("Invalid upload file")
|
|
||||||
|
|
||||||
new_files.append(file_obj)
|
|
||||||
|
|
||||||
# return all file objs
|
|
||||||
return new_files
|
|
||||||
|
|
||||||
def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig):
|
|
||||||
"""
|
|
||||||
transform message files
|
|
||||||
|
|
||||||
:param files:
|
|
||||||
:param file_extra_config:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
# transform files to file objs
|
|
||||||
type_file_objs = self._to_file_objs(files, file_extra_config)
|
|
||||||
|
|
||||||
# return all file objs
|
|
||||||
return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
|
|
||||||
|
|
||||||
def _to_file_objs(
|
|
||||||
self, files: list[Union[dict, MessageFile]], file_extra_config: FileExtraConfig
|
|
||||||
) -> dict[FileType, list[FileVar]]:
|
|
||||||
"""
|
|
||||||
transform files to file objs
|
|
||||||
|
|
||||||
:param files:
|
|
||||||
:param file_extra_config:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
type_file_objs: dict[FileType, list[FileVar]] = {
|
|
||||||
# Currently only support image
|
|
||||||
FileType.IMAGE: []
|
|
||||||
}
|
|
||||||
|
|
||||||
if not files:
|
|
||||||
return type_file_objs
|
|
||||||
|
|
||||||
# group by file type and convert file args or message files to FileObj
|
|
||||||
for file in files:
|
|
||||||
if isinstance(file, MessageFile):
|
|
||||||
if file.belongs_to == FileBelongsTo.ASSISTANT.value:
|
|
||||||
continue
|
|
||||||
|
|
||||||
file_obj = self._to_file_obj(file, file_extra_config)
|
|
||||||
if file_obj.type not in type_file_objs:
|
|
||||||
continue
|
|
||||||
|
|
||||||
type_file_objs[file_obj.type].append(file_obj)
|
|
||||||
|
|
||||||
return type_file_objs
|
|
||||||
|
|
||||||
def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig):
|
|
||||||
"""
|
|
||||||
transform file to file obj
|
|
||||||
|
|
||||||
:param file:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
if isinstance(file, dict):
|
|
||||||
transfer_method = FileTransferMethod.value_of(file.get("transfer_method"))
|
|
||||||
if transfer_method != FileTransferMethod.TOOL_FILE:
|
|
||||||
return FileVar(
|
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
type=FileType.value_of(file.get("type")),
|
|
||||||
transfer_method=transfer_method,
|
|
||||||
url=file.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None,
|
|
||||||
related_id=file.get("upload_file_id") if transfer_method == FileTransferMethod.LOCAL_FILE else None,
|
|
||||||
extra_config=file_extra_config,
|
|
||||||
)
|
|
||||||
return FileVar(
|
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
type=FileType.value_of(file.get("type")),
|
|
||||||
transfer_method=transfer_method,
|
|
||||||
url=None,
|
|
||||||
related_id=file.get("tool_file_id"),
|
|
||||||
extra_config=file_extra_config,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return FileVar(
|
|
||||||
id=file.id,
|
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
type=FileType.value_of(file.type),
|
|
||||||
transfer_method=FileTransferMethod.value_of(file.transfer_method),
|
|
||||||
url=file.url,
|
|
||||||
related_id=file.upload_file_id or None,
|
|
||||||
extra_config=file_extra_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _check_image_remote_url(self, url):
|
|
||||||
try:
|
|
||||||
headers = {
|
|
||||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
|
|
||||||
" Chrome/91.0.4472.124 Safari/537.36"
|
|
||||||
}
|
|
||||||
|
|
||||||
def is_s3_presigned_url(url):
|
|
||||||
try:
|
|
||||||
parsed_url = urlparse(url)
|
|
||||||
if "amazonaws.com" not in parsed_url.netloc:
|
|
||||||
return False
|
|
||||||
query_params = parse_qs(parsed_url.query)
|
|
||||||
required_params = ["Signature", "Expires"]
|
|
||||||
for param in required_params:
|
|
||||||
if param not in query_params:
|
|
||||||
return False
|
|
||||||
if not query_params["Expires"][0].isdigit():
|
|
||||||
return False
|
|
||||||
signature = query_params["Signature"][0]
|
|
||||||
if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if is_s3_presigned_url(url):
|
|
||||||
response = requests.get(url, headers=headers, allow_redirects=True)
|
|
||||||
if response.status_code in {200, 304}:
|
|
||||||
return True, ""
|
|
||||||
|
|
||||||
response = requests.head(url, headers=headers, allow_redirects=True)
|
|
||||||
if response.status_code in {200, 304}:
|
|
||||||
return True, ""
|
|
||||||
else:
|
|
||||||
return False, "URL does not exist."
|
|
||||||
except requests.RequestException as e:
|
|
||||||
return False, f"Error checking URL: {e}"
|
|
||||||
|
|
@ -0,0 +1,140 @@
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
|
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||||
|
|
||||||
|
from . import helpers
|
||||||
|
from .constants import FILE_MODEL_IDENTITY
|
||||||
|
from .enums import FileTransferMethod, FileType
|
||||||
|
from .tool_file_parser import ToolFileParser
|
||||||
|
|
||||||
|
|
||||||
|
class ImageConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
NOTE: This part of validation is deprecated, but still used in app features "Image Upload".
|
||||||
|
"""
|
||||||
|
|
||||||
|
number_limits: int = 0
|
||||||
|
transfer_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
|
||||||
|
detail: ImagePromptMessageContent.DETAIL | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class FileExtraConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
File Upload Entity.
|
||||||
|
"""
|
||||||
|
|
||||||
|
image_config: Optional[ImageConfig] = None
|
||||||
|
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
|
||||||
|
allowed_extensions: Sequence[str] = Field(default_factory=list)
|
||||||
|
allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
|
||||||
|
number_limits: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class File(BaseModel):
|
||||||
|
model_identity: str = FILE_MODEL_IDENTITY
|
||||||
|
|
||||||
|
id: Optional[str] = None # message file id
|
||||||
|
tenant_id: str
|
||||||
|
type: FileType
|
||||||
|
transfer_method: FileTransferMethod
|
||||||
|
remote_url: Optional[str] = None # remote url
|
||||||
|
related_id: Optional[str] = None
|
||||||
|
filename: Optional[str] = None
|
||||||
|
extension: Optional[str] = None
|
||||||
|
mime_type: Optional[str] = None
|
||||||
|
size: int = 0
|
||||||
|
_extra_config: FileExtraConfig | None = None
|
||||||
|
|
||||||
|
def to_dict(self) -> Mapping[str, str | int | None]:
|
||||||
|
data = self.model_dump()
|
||||||
|
return {
|
||||||
|
**data,
|
||||||
|
"url": self.generate_url(),
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def markdown(self) -> str:
|
||||||
|
url = self.generate_url()
|
||||||
|
if self.type == FileType.IMAGE:
|
||||||
|
text = f''
|
||||||
|
else:
|
||||||
|
text = f"[{self.filename or url}]({url})"
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
def generate_url(self) -> Optional[str]:
|
||||||
|
if self.type == FileType.IMAGE:
|
||||||
|
if self.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||||
|
return self.remote_url
|
||||||
|
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||||
|
if self.related_id is None:
|
||||||
|
raise ValueError("Missing file related_id")
|
||||||
|
return helpers.get_signed_image_url(upload_file_id=self.related_id)
|
||||||
|
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||||
|
assert self.related_id is not None
|
||||||
|
assert self.extension is not None
|
||||||
|
return ToolFileParser.get_tool_file_manager().sign_file(
|
||||||
|
tool_file_id=self.related_id, extension=self.extension
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if self.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||||
|
return self.remote_url
|
||||||
|
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||||
|
if self.related_id is None:
|
||||||
|
raise ValueError("Missing file related_id")
|
||||||
|
return helpers.get_signed_file_url(upload_file_id=self.related_id)
|
||||||
|
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||||
|
assert self.related_id is not None
|
||||||
|
assert self.extension is not None
|
||||||
|
return ToolFileParser.get_tool_file_manager().sign_file(
|
||||||
|
tool_file_id=self.related_id, extension=self.extension
|
||||||
|
)
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_after(self):
|
||||||
|
match self.transfer_method:
|
||||||
|
case FileTransferMethod.REMOTE_URL:
|
||||||
|
if not self.remote_url:
|
||||||
|
raise ValueError("Missing file url")
|
||||||
|
if not isinstance(self.remote_url, str) or not self.remote_url.startswith("http"):
|
||||||
|
raise ValueError("Invalid file url")
|
||||||
|
case FileTransferMethod.LOCAL_FILE:
|
||||||
|
if not self.related_id:
|
||||||
|
raise ValueError("Missing file related_id")
|
||||||
|
case FileTransferMethod.TOOL_FILE:
|
||||||
|
if not self.related_id:
|
||||||
|
raise ValueError("Missing file related_id")
|
||||||
|
|
||||||
|
# Validate the extra config.
|
||||||
|
if not self._extra_config:
|
||||||
|
return self
|
||||||
|
|
||||||
|
if self._extra_config.allowed_file_types:
|
||||||
|
if self.type not in self._extra_config.allowed_file_types and self.type != FileType.CUSTOM:
|
||||||
|
raise ValueError(f"Invalid file type: {self.type}")
|
||||||
|
|
||||||
|
if self._extra_config.allowed_extensions and self.extension not in self._extra_config.allowed_extensions:
|
||||||
|
raise ValueError(f"Invalid file extension: {self.extension}")
|
||||||
|
|
||||||
|
if (
|
||||||
|
self._extra_config.allowed_upload_methods
|
||||||
|
and self.transfer_method not in self._extra_config.allowed_upload_methods
|
||||||
|
):
|
||||||
|
raise ValueError(f"Invalid transfer method: {self.transfer_method}")
|
||||||
|
|
||||||
|
match self.type:
|
||||||
|
case FileType.IMAGE:
|
||||||
|
# NOTE: This part of validation is deprecated, but still used in app features "Image Upload".
|
||||||
|
if not self._extra_config.image_config:
|
||||||
|
return self
|
||||||
|
# TODO: skip check if transfer_methods is empty, because many test cases are not setting this field
|
||||||
|
if (
|
||||||
|
self._extra_config.image_config.transfer_methods
|
||||||
|
and self.transfer_method not in self._extra_config.image_config.transfer_methods
|
||||||
|
):
|
||||||
|
raise ValueError(f"Invalid transfer method: {self.transfer_method}")
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
@ -1,4 +1,9 @@
|
||||||
tool_file_manager = {"manager": None}
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
|
|
||||||
|
tool_file_manager: dict[str, Any] = {"manager": None}
|
||||||
|
|
||||||
|
|
||||||
class ToolFileParser:
|
class ToolFileParser:
|
||||||
|
|
|
||||||
|
|
@ -1,79 +0,0 @@
|
||||||
import base64
|
|
||||||
import hashlib
|
|
||||||
import hmac
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from configs import dify_config
|
|
||||||
from extensions.ext_storage import storage
|
|
||||||
|
|
||||||
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
|
|
||||||
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
|
|
||||||
|
|
||||||
|
|
||||||
class UploadFileParser:
|
|
||||||
@classmethod
|
|
||||||
def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
|
|
||||||
if not upload_file:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if upload_file.extension not in IMAGE_EXTENSIONS:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url" or force_url:
|
|
||||||
return cls.get_signed_temp_image_url(upload_file.id)
|
|
||||||
else:
|
|
||||||
# get image file base64
|
|
||||||
try:
|
|
||||||
data = storage.load(upload_file.key)
|
|
||||||
except FileNotFoundError:
|
|
||||||
logging.error(f"File not found: {upload_file.key}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
encoded_string = base64.b64encode(data).decode("utf-8")
|
|
||||||
return f"data:{upload_file.mime_type};base64,{encoded_string}"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_signed_temp_image_url(cls, upload_file_id) -> str:
|
|
||||||
"""
|
|
||||||
get signed url from upload file
|
|
||||||
|
|
||||||
:param upload_file: UploadFile object
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
base_url = dify_config.FILES_URL
|
|
||||||
image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"
|
|
||||||
|
|
||||||
timestamp = str(int(time.time()))
|
|
||||||
nonce = os.urandom(16).hex()
|
|
||||||
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
|
|
||||||
secret_key = dify_config.SECRET_KEY.encode()
|
|
||||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
|
||||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
|
||||||
|
|
||||||
return f"{image_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
|
||||||
"""
|
|
||||||
verify signature
|
|
||||||
|
|
||||||
:param upload_file_id: file id
|
|
||||||
:param timestamp: timestamp
|
|
||||||
:param nonce: nonce
|
|
||||||
:param sign: signature
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
|
|
||||||
secret_key = dify_config.SECRET_KEY.encode()
|
|
||||||
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
|
||||||
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
|
||||||
|
|
||||||
# verify signature
|
|
||||||
if sign != recalculated_encoded_sign:
|
|
||||||
return False
|
|
||||||
|
|
||||||
current_time = int(time.time())
|
|
||||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
|
||||||
|
|
@ -13,8 +13,11 @@ SSRF_PROXY_HTTP_URL = os.getenv("SSRF_PROXY_HTTP_URL", "")
|
||||||
SSRF_PROXY_HTTPS_URL = os.getenv("SSRF_PROXY_HTTPS_URL", "")
|
SSRF_PROXY_HTTPS_URL = os.getenv("SSRF_PROXY_HTTPS_URL", "")
|
||||||
SSRF_DEFAULT_MAX_RETRIES = int(os.getenv("SSRF_DEFAULT_MAX_RETRIES", "3"))
|
SSRF_DEFAULT_MAX_RETRIES = int(os.getenv("SSRF_DEFAULT_MAX_RETRIES", "3"))
|
||||||
|
|
||||||
proxies = (
|
proxy_mounts = (
|
||||||
{"http://": SSRF_PROXY_HTTP_URL, "https://": SSRF_PROXY_HTTPS_URL}
|
{
|
||||||
|
"http://": httpx.HTTPTransport(SSRF_PROXY_HTTP_URL),
|
||||||
|
"https://": httpx.HTTPTransport(SSRF_PROXY_HTTPS_URL),
|
||||||
|
}
|
||||||
if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL
|
if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
@ -33,11 +36,14 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||||
while retries <= max_retries:
|
while retries <= max_retries:
|
||||||
try:
|
try:
|
||||||
if SSRF_PROXY_ALL_URL:
|
if SSRF_PROXY_ALL_URL:
|
||||||
response = httpx.request(method=method, url=url, proxy=SSRF_PROXY_ALL_URL, **kwargs)
|
with httpx.Client(proxy=SSRF_PROXY_ALL_URL) as client:
|
||||||
elif proxies:
|
response = client.request(method=method, url=url, **kwargs)
|
||||||
response = httpx.request(method=method, url=url, proxies=proxies, **kwargs)
|
elif proxy_mounts:
|
||||||
|
with httpx.Client(mounts=proxy_mounts) as client:
|
||||||
|
response = client.request(method=method, url=url, **kwargs)
|
||||||
else:
|
else:
|
||||||
response = httpx.request(method=method, url=url, **kwargs)
|
with httpx.Client() as client:
|
||||||
|
response = client.request(method=method, url=url, **kwargs)
|
||||||
|
|
||||||
if response.status_code not in STATUS_FORCELIST:
|
if response.status_code not in STATUS_FORCELIST:
|
||||||
return response
|
return response
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,20 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||||
from core.file.message_file_parser import MessageFileParser
|
from core.file import file_manager
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
ImagePromptMessageContent,
|
ImagePromptMessageContent,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
|
PromptMessageContent,
|
||||||
PromptMessageRole,
|
PromptMessageRole,
|
||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
)
|
)
|
||||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from factories import file_factory
|
||||||
from models.model import AppMode, Conversation, Message, MessageFile
|
from models.model import AppMode, Conversation, Message, MessageFile
|
||||||
from models.workflow import WorkflowRun
|
from models.workflow import WorkflowRun
|
||||||
|
|
||||||
|
|
@ -61,7 +63,6 @@ class TokenBufferMemory:
|
||||||
thread_messages.pop(0)
|
thread_messages.pop(0)
|
||||||
messages = list(reversed(thread_messages))
|
messages = list(reversed(thread_messages))
|
||||||
|
|
||||||
message_file_parser = MessageFileParser(tenant_id=app_record.tenant_id, app_id=app_record.id)
|
|
||||||
prompt_messages = []
|
prompt_messages = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
|
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
|
||||||
|
|
@ -80,17 +81,20 @@ class TokenBufferMemory:
|
||||||
workflow_run.workflow.features_dict, is_vision=False
|
workflow_run.workflow.features_dict, is_vision=False
|
||||||
)
|
)
|
||||||
|
|
||||||
if file_extra_config:
|
if file_extra_config and app_record:
|
||||||
file_objs = message_file_parser.transform_message_files(files, file_extra_config)
|
file_objs = file_factory.build_from_message_files(
|
||||||
|
message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
file_objs = []
|
file_objs = []
|
||||||
|
|
||||||
if not file_objs:
|
if not file_objs:
|
||||||
prompt_messages.append(UserPromptMessage(content=message.query))
|
prompt_messages.append(UserPromptMessage(content=message.query))
|
||||||
else:
|
else:
|
||||||
prompt_message_contents = [TextPromptMessageContent(data=message.query)]
|
prompt_message_contents: list[PromptMessageContent] = []
|
||||||
|
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
|
||||||
for file_obj in file_objs:
|
for file_obj in file_objs:
|
||||||
prompt_message_contents.append(file_obj.prompt_message_content)
|
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
|
||||||
|
|
||||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,36 @@
|
||||||
|
from .llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||||
|
from .message_entities import (
|
||||||
|
AssistantPromptMessage,
|
||||||
|
ImagePromptMessageContent,
|
||||||
|
PromptMessage,
|
||||||
|
PromptMessageContent,
|
||||||
|
PromptMessageContentType,
|
||||||
|
PromptMessageRole,
|
||||||
|
PromptMessageTool,
|
||||||
|
SystemPromptMessage,
|
||||||
|
TextPromptMessageContent,
|
||||||
|
ToolPromptMessage,
|
||||||
|
UserPromptMessage,
|
||||||
|
)
|
||||||
|
from .model_entities import ModelPropertyKey
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ImagePromptMessageContent",
|
||||||
|
"PromptMessage",
|
||||||
|
"PromptMessageRole",
|
||||||
|
"LLMUsage",
|
||||||
|
"ModelPropertyKey",
|
||||||
|
"AssistantPromptMessage",
|
||||||
|
"PromptMessage",
|
||||||
|
"PromptMessageContent",
|
||||||
|
"PromptMessageRole",
|
||||||
|
"SystemPromptMessage",
|
||||||
|
"TextPromptMessageContent",
|
||||||
|
"UserPromptMessage",
|
||||||
|
"PromptMessageTool",
|
||||||
|
"ToolPromptMessage",
|
||||||
|
"PromptMessageContentType",
|
||||||
|
"LLMResult",
|
||||||
|
"LLMResultChunk",
|
||||||
|
"LLMResultChunkDelta",
|
||||||
|
]
|
||||||
|
|
@ -79,7 +79,7 @@ class ImagePromptMessageContent(PromptMessageContent):
|
||||||
Model class for image prompt message content.
|
Model class for image prompt message content.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class DETAIL(Enum):
|
class DETAIL(str, Enum):
|
||||||
LOW = "low"
|
LOW = "low"
|
||||||
HIGH = "high"
|
HIGH = "high"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
|
|
@ -8,6 +7,7 @@ from typing import Optional, Union
|
||||||
|
|
||||||
from pydantic import ConfigDict
|
from pydantic import ConfigDict
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
from core.model_runtime.callbacks.base_callback import Callback
|
from core.model_runtime.callbacks.base_callback import Callback
|
||||||
from core.model_runtime.callbacks.logging_callback import LoggingCallback
|
from core.model_runtime.callbacks.logging_callback import LoggingCallback
|
||||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||||
|
|
@ -77,7 +77,7 @@ class LargeLanguageModel(AIModel):
|
||||||
|
|
||||||
callbacks = callbacks or []
|
callbacks = callbacks or []
|
||||||
|
|
||||||
if bool(os.environ.get("DEBUG", "False").lower() == "true"):
|
if dify_config.DEBUG:
|
||||||
callbacks.append(LoggingCallback())
|
callbacks.append(LoggingCallback())
|
||||||
|
|
||||||
# trigger before invoke callbacks
|
# trigger before invoke callbacks
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,15 @@
|
||||||
from typing import Optional, Union
|
from collections.abc import Sequence
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||||
from core.file.file_obj import FileVar
|
from core.file import file_manager
|
||||||
|
from core.file.models import File
|
||||||
from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
|
from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
|
PromptMessageContent,
|
||||||
PromptMessageRole,
|
PromptMessageRole,
|
||||||
SystemPromptMessage,
|
SystemPromptMessage,
|
||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
|
|
@ -14,7 +17,6 @@ from core.model_runtime.entities.message_entities import (
|
||||||
)
|
)
|
||||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||||
from core.prompt.prompt_transform import PromptTransform
|
from core.prompt.prompt_transform import PromptTransform
|
||||||
from core.prompt.simple_prompt_transform import ModelMode
|
|
||||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -28,22 +30,19 @@ class AdvancedPromptTransform(PromptTransform):
|
||||||
|
|
||||||
def get_prompt(
|
def get_prompt(
|
||||||
self,
|
self,
|
||||||
prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate],
|
*,
|
||||||
inputs: dict,
|
prompt_template: Sequence[ChatModelMessage] | CompletionModelPromptTemplate,
|
||||||
|
inputs: dict[str, str],
|
||||||
query: str,
|
query: str,
|
||||||
files: list[FileVar],
|
files: Sequence[File],
|
||||||
context: Optional[str],
|
context: Optional[str],
|
||||||
memory_config: Optional[MemoryConfig],
|
memory_config: Optional[MemoryConfig],
|
||||||
memory: Optional[TokenBufferMemory],
|
memory: Optional[TokenBufferMemory],
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
query_prompt_template: Optional[str] = None,
|
|
||||||
) -> list[PromptMessage]:
|
) -> list[PromptMessage]:
|
||||||
inputs = {key: str(value) for key, value in inputs.items()}
|
|
||||||
|
|
||||||
prompt_messages = []
|
prompt_messages = []
|
||||||
|
|
||||||
model_mode = ModelMode.value_of(model_config.mode)
|
if isinstance(prompt_template, CompletionModelPromptTemplate):
|
||||||
if model_mode == ModelMode.COMPLETION:
|
|
||||||
prompt_messages = self._get_completion_model_prompt_messages(
|
prompt_messages = self._get_completion_model_prompt_messages(
|
||||||
prompt_template=prompt_template,
|
prompt_template=prompt_template,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
|
|
@ -54,12 +53,11 @@ class AdvancedPromptTransform(PromptTransform):
|
||||||
memory=memory,
|
memory=memory,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
)
|
)
|
||||||
elif model_mode == ModelMode.CHAT:
|
elif isinstance(prompt_template, list) and all(isinstance(item, ChatModelMessage) for item in prompt_template):
|
||||||
prompt_messages = self._get_chat_model_prompt_messages(
|
prompt_messages = self._get_chat_model_prompt_messages(
|
||||||
prompt_template=prompt_template,
|
prompt_template=prompt_template,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
query=query,
|
query=query,
|
||||||
query_prompt_template=query_prompt_template,
|
|
||||||
files=files,
|
files=files,
|
||||||
context=context,
|
context=context,
|
||||||
memory_config=memory_config,
|
memory_config=memory_config,
|
||||||
|
|
@ -74,7 +72,7 @@ class AdvancedPromptTransform(PromptTransform):
|
||||||
prompt_template: CompletionModelPromptTemplate,
|
prompt_template: CompletionModelPromptTemplate,
|
||||||
inputs: dict,
|
inputs: dict,
|
||||||
query: Optional[str],
|
query: Optional[str],
|
||||||
files: list[FileVar],
|
files: Sequence[File],
|
||||||
context: Optional[str],
|
context: Optional[str],
|
||||||
memory_config: Optional[MemoryConfig],
|
memory_config: Optional[MemoryConfig],
|
||||||
memory: Optional[TokenBufferMemory],
|
memory: Optional[TokenBufferMemory],
|
||||||
|
|
@ -88,10 +86,10 @@ class AdvancedPromptTransform(PromptTransform):
|
||||||
prompt_messages = []
|
prompt_messages = []
|
||||||
|
|
||||||
if prompt_template.edition_type == "basic" or not prompt_template.edition_type:
|
if prompt_template.edition_type == "basic" or not prompt_template.edition_type:
|
||||||
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
|
parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
|
||||||
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
|
||||||
|
|
||||||
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
|
prompt_inputs = self._set_context_variable(context, parser, prompt_inputs)
|
||||||
|
|
||||||
if memory and memory_config:
|
if memory and memory_config:
|
||||||
role_prefix = memory_config.role_prefix
|
role_prefix = memory_config.role_prefix
|
||||||
|
|
@ -100,15 +98,15 @@ class AdvancedPromptTransform(PromptTransform):
|
||||||
memory_config=memory_config,
|
memory_config=memory_config,
|
||||||
raw_prompt=raw_prompt,
|
raw_prompt=raw_prompt,
|
||||||
role_prefix=role_prefix,
|
role_prefix=role_prefix,
|
||||||
prompt_template=prompt_template,
|
parser=parser,
|
||||||
prompt_inputs=prompt_inputs,
|
prompt_inputs=prompt_inputs,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
if query:
|
if query:
|
||||||
prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)
|
prompt_inputs = self._set_query_variable(query, parser, prompt_inputs)
|
||||||
|
|
||||||
prompt = prompt_template.format(prompt_inputs)
|
prompt = parser.format(prompt_inputs)
|
||||||
else:
|
else:
|
||||||
prompt = raw_prompt
|
prompt = raw_prompt
|
||||||
prompt_inputs = inputs
|
prompt_inputs = inputs
|
||||||
|
|
@ -116,9 +114,10 @@ class AdvancedPromptTransform(PromptTransform):
|
||||||
prompt = Jinja2Formatter.format(prompt, prompt_inputs)
|
prompt = Jinja2Formatter.format(prompt, prompt_inputs)
|
||||||
|
|
||||||
if files:
|
if files:
|
||||||
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
|
prompt_message_contents: list[PromptMessageContent] = []
|
||||||
|
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
|
||||||
for file in files:
|
for file in files:
|
||||||
prompt_message_contents.append(file.prompt_message_content)
|
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
|
||||||
|
|
||||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||||
else:
|
else:
|
||||||
|
|
@ -131,35 +130,28 @@ class AdvancedPromptTransform(PromptTransform):
|
||||||
prompt_template: list[ChatModelMessage],
|
prompt_template: list[ChatModelMessage],
|
||||||
inputs: dict,
|
inputs: dict,
|
||||||
query: Optional[str],
|
query: Optional[str],
|
||||||
files: list[FileVar],
|
files: Sequence[File],
|
||||||
context: Optional[str],
|
context: Optional[str],
|
||||||
memory_config: Optional[MemoryConfig],
|
memory_config: Optional[MemoryConfig],
|
||||||
memory: Optional[TokenBufferMemory],
|
memory: Optional[TokenBufferMemory],
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
query_prompt_template: Optional[str] = None,
|
|
||||||
) -> list[PromptMessage]:
|
) -> list[PromptMessage]:
|
||||||
"""
|
"""
|
||||||
Get chat model prompt messages.
|
Get chat model prompt messages.
|
||||||
"""
|
"""
|
||||||
raw_prompt_list = prompt_template
|
|
||||||
|
|
||||||
prompt_messages = []
|
prompt_messages = []
|
||||||
|
for prompt_item in prompt_template:
|
||||||
for prompt_item in raw_prompt_list:
|
|
||||||
raw_prompt = prompt_item.text
|
raw_prompt = prompt_item.text
|
||||||
|
|
||||||
if prompt_item.edition_type == "basic" or not prompt_item.edition_type:
|
if prompt_item.edition_type == "basic" or not prompt_item.edition_type:
|
||||||
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
|
parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
|
||||||
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
|
||||||
|
prompt_inputs = self._set_context_variable(context=context, parser=parser, prompt_inputs=prompt_inputs)
|
||||||
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
|
prompt = parser.format(prompt_inputs)
|
||||||
|
|
||||||
prompt = prompt_template.format(prompt_inputs)
|
|
||||||
elif prompt_item.edition_type == "jinja2":
|
elif prompt_item.edition_type == "jinja2":
|
||||||
prompt = raw_prompt
|
prompt = raw_prompt
|
||||||
prompt_inputs = inputs
|
prompt_inputs = inputs
|
||||||
|
prompt = Jinja2Formatter.format(template=prompt, inputs=prompt_inputs)
|
||||||
prompt = Jinja2Formatter.format(prompt, prompt_inputs)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid edition type: {prompt_item.edition_type}")
|
raise ValueError(f"Invalid edition type: {prompt_item.edition_type}")
|
||||||
|
|
||||||
|
|
@ -170,25 +162,25 @@ class AdvancedPromptTransform(PromptTransform):
|
||||||
elif prompt_item.role == PromptMessageRole.ASSISTANT:
|
elif prompt_item.role == PromptMessageRole.ASSISTANT:
|
||||||
prompt_messages.append(AssistantPromptMessage(content=prompt))
|
prompt_messages.append(AssistantPromptMessage(content=prompt))
|
||||||
|
|
||||||
if query and query_prompt_template:
|
if query and memory_config and memory_config.query_prompt_template:
|
||||||
prompt_template = PromptTemplateParser(
|
parser = PromptTemplateParser(
|
||||||
template=query_prompt_template, with_variable_tmpl=self.with_variable_tmpl
|
template=memory_config.query_prompt_template, with_variable_tmpl=self.with_variable_tmpl
|
||||||
)
|
)
|
||||||
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
|
||||||
prompt_inputs["#sys.query#"] = query
|
prompt_inputs["#sys.query#"] = query
|
||||||
|
|
||||||
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
|
prompt_inputs = self._set_context_variable(context, parser, prompt_inputs)
|
||||||
|
|
||||||
query = prompt_template.format(prompt_inputs)
|
query = parser.format(prompt_inputs)
|
||||||
|
|
||||||
if memory and memory_config:
|
if memory and memory_config:
|
||||||
prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
|
prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
|
||||||
|
|
||||||
if files:
|
if files and query is not None:
|
||||||
prompt_message_contents = [TextPromptMessageContent(data=query)]
|
prompt_message_contents: list[PromptMessageContent] = []
|
||||||
|
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||||
for file in files:
|
for file in files:
|
||||||
prompt_message_contents.append(file.prompt_message_content)
|
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
|
||||||
|
|
||||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||||
else:
|
else:
|
||||||
prompt_messages.append(UserPromptMessage(content=query))
|
prompt_messages.append(UserPromptMessage(content=query))
|
||||||
|
|
@ -200,19 +192,19 @@ class AdvancedPromptTransform(PromptTransform):
|
||||||
# get last user message content and add files
|
# get last user message content and add files
|
||||||
prompt_message_contents = [TextPromptMessageContent(data=last_message.content)]
|
prompt_message_contents = [TextPromptMessageContent(data=last_message.content)]
|
||||||
for file in files:
|
for file in files:
|
||||||
prompt_message_contents.append(file.prompt_message_content)
|
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
|
||||||
|
|
||||||
last_message.content = prompt_message_contents
|
last_message.content = prompt_message_contents
|
||||||
else:
|
else:
|
||||||
prompt_message_contents = [TextPromptMessageContent(data="")] # not for query
|
prompt_message_contents = [TextPromptMessageContent(data="")] # not for query
|
||||||
for file in files:
|
for file in files:
|
||||||
prompt_message_contents.append(file.prompt_message_content)
|
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
|
||||||
|
|
||||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||||
else:
|
else:
|
||||||
prompt_message_contents = [TextPromptMessageContent(data=query)]
|
prompt_message_contents = [TextPromptMessageContent(data=query)]
|
||||||
for file in files:
|
for file in files:
|
||||||
prompt_message_contents.append(file.prompt_message_content)
|
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
|
||||||
|
|
||||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||||
elif query:
|
elif query:
|
||||||
|
|
@ -220,8 +212,8 @@ class AdvancedPromptTransform(PromptTransform):
|
||||||
|
|
||||||
return prompt_messages
|
return prompt_messages
|
||||||
|
|
||||||
def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict:
|
def _set_context_variable(self, context: str | None, parser: PromptTemplateParser, prompt_inputs: dict) -> dict:
|
||||||
if "#context#" in prompt_template.variable_keys:
|
if "#context#" in parser.variable_keys:
|
||||||
if context:
|
if context:
|
||||||
prompt_inputs["#context#"] = context
|
prompt_inputs["#context#"] = context
|
||||||
else:
|
else:
|
||||||
|
|
@ -229,8 +221,8 @@ class AdvancedPromptTransform(PromptTransform):
|
||||||
|
|
||||||
return prompt_inputs
|
return prompt_inputs
|
||||||
|
|
||||||
def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict:
|
def _set_query_variable(self, query: str, parser: PromptTemplateParser, prompt_inputs: dict) -> dict:
|
||||||
if "#query#" in prompt_template.variable_keys:
|
if "#query#" in parser.variable_keys:
|
||||||
if query:
|
if query:
|
||||||
prompt_inputs["#query#"] = query
|
prompt_inputs["#query#"] = query
|
||||||
else:
|
else:
|
||||||
|
|
@ -244,16 +236,16 @@ class AdvancedPromptTransform(PromptTransform):
|
||||||
memory_config: MemoryConfig,
|
memory_config: MemoryConfig,
|
||||||
raw_prompt: str,
|
raw_prompt: str,
|
||||||
role_prefix: MemoryConfig.RolePrefix,
|
role_prefix: MemoryConfig.RolePrefix,
|
||||||
prompt_template: PromptTemplateParser,
|
parser: PromptTemplateParser,
|
||||||
prompt_inputs: dict,
|
prompt_inputs: dict,
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
if "#histories#" in prompt_template.variable_keys:
|
if "#histories#" in parser.variable_keys:
|
||||||
if memory:
|
if memory:
|
||||||
inputs = {"#histories#": "", **prompt_inputs}
|
inputs = {"#histories#": "", **prompt_inputs}
|
||||||
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
|
parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
|
||||||
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
|
||||||
tmp_human_message = UserPromptMessage(content=prompt_template.format(prompt_inputs))
|
tmp_human_message = UserPromptMessage(content=parser.format(prompt_inputs))
|
||||||
|
|
||||||
rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
|
rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,9 +5,11 @@ from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from core.app.app_config.entities import PromptTemplateEntity
|
from core.app.app_config.entities import PromptTemplateEntity
|
||||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||||
|
from core.file import file_manager
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities.message_entities import (
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
|
PromptMessageContent,
|
||||||
SystemPromptMessage,
|
SystemPromptMessage,
|
||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
|
|
@ -18,7 +20,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from core.file.file_obj import FileVar
|
from core.file.models import File
|
||||||
|
|
||||||
|
|
||||||
class ModelMode(enum.Enum):
|
class ModelMode(enum.Enum):
|
||||||
|
|
@ -53,7 +55,7 @@ class SimplePromptTransform(PromptTransform):
|
||||||
prompt_template_entity: PromptTemplateEntity,
|
prompt_template_entity: PromptTemplateEntity,
|
||||||
inputs: dict,
|
inputs: dict,
|
||||||
query: str,
|
query: str,
|
||||||
files: list["FileVar"],
|
files: list["File"],
|
||||||
context: Optional[str],
|
context: Optional[str],
|
||||||
memory: Optional[TokenBufferMemory],
|
memory: Optional[TokenBufferMemory],
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
|
|
@ -169,7 +171,7 @@ class SimplePromptTransform(PromptTransform):
|
||||||
inputs: dict,
|
inputs: dict,
|
||||||
query: str,
|
query: str,
|
||||||
context: Optional[str],
|
context: Optional[str],
|
||||||
files: list["FileVar"],
|
files: list["File"],
|
||||||
memory: Optional[TokenBufferMemory],
|
memory: Optional[TokenBufferMemory],
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
) -> tuple[list[PromptMessage], Optional[list[str]]]:
|
) -> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||||
|
|
@ -214,7 +216,7 @@ class SimplePromptTransform(PromptTransform):
|
||||||
inputs: dict,
|
inputs: dict,
|
||||||
query: str,
|
query: str,
|
||||||
context: Optional[str],
|
context: Optional[str],
|
||||||
files: list["FileVar"],
|
files: list["File"],
|
||||||
memory: Optional[TokenBufferMemory],
|
memory: Optional[TokenBufferMemory],
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
) -> tuple[list[PromptMessage], Optional[list[str]]]:
|
) -> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||||
|
|
@ -261,11 +263,12 @@ class SimplePromptTransform(PromptTransform):
|
||||||
|
|
||||||
return [self.get_last_user_message(prompt, files)], stops
|
return [self.get_last_user_message(prompt, files)], stops
|
||||||
|
|
||||||
def get_last_user_message(self, prompt: str, files: list["FileVar"]) -> UserPromptMessage:
|
def get_last_user_message(self, prompt: str, files: list["File"]) -> UserPromptMessage:
|
||||||
if files:
|
if files:
|
||||||
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
|
prompt_message_contents: list[PromptMessageContent] = []
|
||||||
|
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
|
||||||
for file in files:
|
for file in files:
|
||||||
prompt_message_contents.append(file.prompt_message_content)
|
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
|
||||||
|
|
||||||
prompt_message = UserPromptMessage(content=prompt_message_contents)
|
prompt_message = UserPromptMessage(content=prompt_message_contents)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -32,8 +32,8 @@ class UserToolProvider(BaseModel):
|
||||||
original_credentials: Optional[dict] = None
|
original_credentials: Optional[dict] = None
|
||||||
is_team_authorization: bool = False
|
is_team_authorization: bool = False
|
||||||
allow_delete: bool = True
|
allow_delete: bool = True
|
||||||
tools: list[UserTool] = None
|
tools: list[UserTool] | None = None
|
||||||
labels: list[str] = None
|
labels: list[str] | None = None
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
# -------------
|
# -------------
|
||||||
|
|
@ -42,7 +42,7 @@ class UserToolProvider(BaseModel):
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
if tool.get("parameters"):
|
if tool.get("parameters"):
|
||||||
for parameter in tool.get("parameters"):
|
for parameter in tool.get("parameters"):
|
||||||
if parameter.get("type") == ToolParameter.ToolParameterType.FILE.value:
|
if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value:
|
||||||
parameter["type"] = "files"
|
parameter["type"] = "files"
|
||||||
# -------------
|
# -------------
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -104,7 +104,7 @@ class ToolInvokeMessage(BaseModel):
|
||||||
BLOB = "blob"
|
BLOB = "blob"
|
||||||
JSON = "json"
|
JSON = "json"
|
||||||
IMAGE_LINK = "image_link"
|
IMAGE_LINK = "image_link"
|
||||||
FILE_VAR = "file_var"
|
FILE = "file"
|
||||||
|
|
||||||
type: MessageType = MessageType.TEXT
|
type: MessageType = MessageType.TEXT
|
||||||
"""
|
"""
|
||||||
|
|
@ -143,6 +143,67 @@ class ToolParameter(BaseModel):
|
||||||
SELECT = "select"
|
SELECT = "select"
|
||||||
SECRET_INPUT = "secret-input"
|
SECRET_INPUT = "secret-input"
|
||||||
FILE = "file"
|
FILE = "file"
|
||||||
|
FILES = "files"
|
||||||
|
|
||||||
|
# deprecated, should not use.
|
||||||
|
SYSTEM_FILES = "systme-files"
|
||||||
|
|
||||||
|
def as_normal_type(self):
|
||||||
|
if self in {
|
||||||
|
ToolParameter.ToolParameterType.SECRET_INPUT,
|
||||||
|
ToolParameter.ToolParameterType.SELECT,
|
||||||
|
}:
|
||||||
|
return "string"
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
def cast_value(self, value: Any, /):
|
||||||
|
try:
|
||||||
|
match self:
|
||||||
|
case (
|
||||||
|
ToolParameter.ToolParameterType.STRING
|
||||||
|
| ToolParameter.ToolParameterType.SECRET_INPUT
|
||||||
|
| ToolParameter.ToolParameterType.SELECT
|
||||||
|
):
|
||||||
|
if value is None:
|
||||||
|
return ""
|
||||||
|
else:
|
||||||
|
return value if isinstance(value, str) else str(value)
|
||||||
|
|
||||||
|
case ToolParameter.ToolParameterType.BOOLEAN:
|
||||||
|
if value is None:
|
||||||
|
return False
|
||||||
|
elif isinstance(value, str):
|
||||||
|
# Allowed YAML boolean value strings: https://yaml.org/type/bool.html
|
||||||
|
# and also '0' for False and '1' for True
|
||||||
|
match value.lower():
|
||||||
|
case "true" | "yes" | "y" | "1":
|
||||||
|
return True
|
||||||
|
case "false" | "no" | "n" | "0":
|
||||||
|
return False
|
||||||
|
case _:
|
||||||
|
return bool(value)
|
||||||
|
else:
|
||||||
|
return value if isinstance(value, bool) else bool(value)
|
||||||
|
|
||||||
|
case ToolParameter.ToolParameterType.NUMBER:
|
||||||
|
if isinstance(value, int | float):
|
||||||
|
return value
|
||||||
|
elif isinstance(value, str) and value:
|
||||||
|
if "." in value:
|
||||||
|
return float(value)
|
||||||
|
else:
|
||||||
|
return int(value)
|
||||||
|
case (
|
||||||
|
ToolParameter.ToolParameterType.SYSTEM_FILES
|
||||||
|
| ToolParameter.ToolParameterType.FILE
|
||||||
|
| ToolParameter.ToolParameterType.FILES
|
||||||
|
):
|
||||||
|
return value
|
||||||
|
case _:
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
raise ValueError(f"The tool parameter value {value} is not in correct type of {parameter_type}.")
|
||||||
|
|
||||||
class ToolParameterForm(Enum):
|
class ToolParameterForm(Enum):
|
||||||
SCHEMA = "schema" # should be set while adding tool
|
SCHEMA = "schema" # should be set while adding tool
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ from typing import Any
|
||||||
|
|
||||||
from duckduckgo_search import DDGS
|
from duckduckgo_search import DDGS
|
||||||
|
|
||||||
from core.file.file_obj import FileTransferMethod
|
from core.file.models import FileTransferMethod
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
from core.tools.tool.builtin_tool import BuiltinTool
|
from core.tools.tool.builtin_tool import BuiltinTool
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,6 @@ from core.tools.errors import (
|
||||||
from core.tools.provider.tool_provider import ToolProviderController
|
from core.tools.provider.tool_provider import ToolProviderController
|
||||||
from core.tools.tool.builtin_tool import BuiltinTool
|
from core.tools.tool.builtin_tool import BuiltinTool
|
||||||
from core.tools.tool.tool import Tool
|
from core.tools.tool.tool import Tool
|
||||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
|
||||||
from core.tools.utils.yaml_utils import load_yaml_file
|
from core.tools.utils.yaml_utils import load_yaml_file
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -208,9 +207,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||||
|
|
||||||
# the parameter is not set currently, set the default value if needed
|
# the parameter is not set currently, set the default value if needed
|
||||||
if parameter_schema.default is not None:
|
if parameter_schema.default is not None:
|
||||||
default_value = ToolParameterConverter.cast_parameter_by_type(
|
default_value = parameter_schema.type.cast_value(parameter_schema.default)
|
||||||
parameter_schema.default, parameter_schema.type
|
|
||||||
)
|
|
||||||
tool_parameters[parameter] = default_value
|
tool_parameters[parameter] = default_value
|
||||||
|
|
||||||
def validate_credentials(self, credentials: dict[str, Any]) -> None:
|
def validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,6 @@ from core.tools.entities.tool_entities import (
|
||||||
)
|
)
|
||||||
from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError
|
from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError
|
||||||
from core.tools.tool.tool import Tool
|
from core.tools.tool.tool import Tool
|
||||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
|
||||||
|
|
||||||
|
|
||||||
class ToolProviderController(BaseModel, ABC):
|
class ToolProviderController(BaseModel, ABC):
|
||||||
|
|
@ -127,9 +126,7 @@ class ToolProviderController(BaseModel, ABC):
|
||||||
|
|
||||||
# the parameter is not set currently, set the default value if needed
|
# the parameter is not set currently, set the default value if needed
|
||||||
if parameter_schema.default is not None:
|
if parameter_schema.default is not None:
|
||||||
tool_parameters[parameter] = ToolParameterConverter.cast_parameter_by_type(
|
tool_parameters[parameter] = parameter_schema.type.cast_value(parameter_schema.default)
|
||||||
parameter_schema.default, parameter_schema.type
|
|
||||||
)
|
|
||||||
|
|
||||||
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
|
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
from core.app.app_config.entities import VariableEntityType
|
||||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject
|
||||||
from core.tools.entities.tool_entities import (
|
from core.tools.entities.tool_entities import (
|
||||||
|
|
@ -23,6 +23,8 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
|
||||||
VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
|
VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
|
||||||
VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
|
VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
|
||||||
VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
|
VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
|
||||||
|
VariableEntityType.FILE: ToolParameter.ToolParameterType.FILE,
|
||||||
|
VariableEntityType.FILE_LIST: ToolParameter.ToolParameterType.FILES,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -36,8 +38,8 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||||
if not app:
|
if not app:
|
||||||
raise ValueError("app not found")
|
raise ValueError("app not found")
|
||||||
|
|
||||||
controller = WorkflowToolProviderController(
|
controller = WorkflowToolProviderController.model_validate(
|
||||||
**{
|
{
|
||||||
"identity": {
|
"identity": {
|
||||||
"author": db_provider.user.name if db_provider.user_id and db_provider.user else "",
|
"author": db_provider.user.name if db_provider.user_id and db_provider.user else "",
|
||||||
"name": db_provider.label,
|
"name": db_provider.label,
|
||||||
|
|
@ -67,7 +69,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||||
:param app: the app
|
:param app: the app
|
||||||
:return: the tool
|
:return: the tool
|
||||||
"""
|
"""
|
||||||
workflow: Workflow = (
|
workflow = (
|
||||||
db.session.query(Workflow)
|
db.session.query(Workflow)
|
||||||
.filter(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
|
.filter(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
|
||||||
.first()
|
.first()
|
||||||
|
|
@ -76,14 +78,14 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||||
raise ValueError("workflow not found")
|
raise ValueError("workflow not found")
|
||||||
|
|
||||||
# fetch start node
|
# fetch start node
|
||||||
graph: dict = workflow.graph_dict
|
graph = workflow.graph_dict
|
||||||
features_dict: dict = workflow.features_dict
|
features_dict = workflow.features_dict
|
||||||
features = WorkflowAppConfigManager.convert_features(config_dict=features_dict, app_mode=AppMode.WORKFLOW)
|
features = WorkflowAppConfigManager.convert_features(config_dict=features_dict, app_mode=AppMode.WORKFLOW)
|
||||||
|
|
||||||
parameters = db_provider.parameter_configurations
|
parameters = db_provider.parameter_configurations
|
||||||
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
|
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
|
||||||
|
|
||||||
def fetch_workflow_variable(variable_name: str) -> VariableEntity:
|
def fetch_workflow_variable(variable_name: str):
|
||||||
return next(filter(lambda x: x.variable == variable_name, variables), None)
|
return next(filter(lambda x: x.variable == variable_name, variables), None)
|
||||||
|
|
||||||
user = db_provider.user
|
user = db_provider.user
|
||||||
|
|
@ -114,7 +116,6 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||||
llm_description=parameter.description,
|
llm_description=parameter.description,
|
||||||
required=variable.required,
|
required=variable.required,
|
||||||
options=options,
|
options=options,
|
||||||
default=variable.default,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif features.file_upload:
|
elif features.file_upload:
|
||||||
|
|
@ -123,7 +124,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||||
name=parameter.name,
|
name=parameter.name,
|
||||||
label=I18nObject(en_US=parameter.name, zh_Hans=parameter.name),
|
label=I18nObject(en_US=parameter.name, zh_Hans=parameter.name),
|
||||||
human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description),
|
human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description),
|
||||||
type=ToolParameter.ToolParameterType.FILE,
|
type=ToolParameter.ToolParameterType.SYSTEM_FILES,
|
||||||
llm_description=parameter.description,
|
llm_description=parameter.description,
|
||||||
required=False,
|
required=False,
|
||||||
form=parameter.form,
|
form=parameter.form,
|
||||||
|
|
|
||||||
|
|
@ -20,10 +20,9 @@ from core.tools.entities.tool_entities import (
|
||||||
ToolRuntimeVariablePool,
|
ToolRuntimeVariablePool,
|
||||||
)
|
)
|
||||||
from core.tools.tool_file_manager import ToolFileManager
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from core.file.file_obj import FileVar
|
from core.file.models import File
|
||||||
|
|
||||||
|
|
||||||
class Tool(BaseModel, ABC):
|
class Tool(BaseModel, ABC):
|
||||||
|
|
@ -221,9 +220,7 @@ class Tool(BaseModel, ABC):
|
||||||
result = deepcopy(tool_parameters)
|
result = deepcopy(tool_parameters)
|
||||||
for parameter in self.parameters or []:
|
for parameter in self.parameters or []:
|
||||||
if parameter.name in tool_parameters:
|
if parameter.name in tool_parameters:
|
||||||
result[parameter.name] = ToolParameterConverter.cast_parameter_by_type(
|
result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name])
|
||||||
tool_parameters[parameter.name], parameter.type
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
@ -295,10 +292,8 @@ class Tool(BaseModel, ABC):
|
||||||
"""
|
"""
|
||||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, message=image, save_as=save_as)
|
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, message=image, save_as=save_as)
|
||||||
|
|
||||||
def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage:
|
def create_file_message(self, file: "File") -> ToolInvokeMessage:
|
||||||
return ToolInvokeMessage(
|
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE, message="", meta={"file": file}, save_as="")
|
||||||
type=ToolInvokeMessage.MessageType.FILE_VAR, message="", meta={"file_var": file_var}, save_as=""
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_link_message(self, link: str, save_as: str = "") -> ToolInvokeMessage:
|
def create_link_message(self, link: str, save_as: str = "") -> ToolInvokeMessage:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ import logging
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from core.file.file_obj import FileTransferMethod, FileVar
|
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
|
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||||
from core.tools.tool.tool import Tool
|
from core.tools.tool.tool import Tool
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
|
@ -45,11 +45,13 @@ class WorkflowTool(Tool):
|
||||||
workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version)
|
workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version)
|
||||||
|
|
||||||
# transform the tool parameters
|
# transform the tool parameters
|
||||||
tool_parameters, files = self._transform_args(tool_parameters)
|
tool_parameters, files = self._transform_args(tool_parameters=tool_parameters)
|
||||||
|
|
||||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||||
|
|
||||||
generator = WorkflowAppGenerator()
|
generator = WorkflowAppGenerator()
|
||||||
|
assert self.runtime is not None
|
||||||
|
assert self.runtime.invoke_from is not None
|
||||||
result = generator.generate(
|
result = generator.generate(
|
||||||
app_model=app,
|
app_model=app,
|
||||||
workflow=workflow,
|
workflow=workflow,
|
||||||
|
|
@ -71,7 +73,7 @@ class WorkflowTool(Tool):
|
||||||
outputs = data.get("outputs", {})
|
outputs = data.get("outputs", {})
|
||||||
outputs, files = self._extract_files(outputs)
|
outputs, files = self._extract_files(outputs)
|
||||||
for file in files:
|
for file in files:
|
||||||
result.append(self.create_file_var_message(file))
|
result.append(self.create_file_message(file))
|
||||||
|
|
||||||
result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False)))
|
result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False)))
|
||||||
result.append(self.create_json_message(outputs))
|
result.append(self.create_json_message(outputs))
|
||||||
|
|
@ -151,22 +153,22 @@ class WorkflowTool(Tool):
|
||||||
parameters_result = {}
|
parameters_result = {}
|
||||||
files = []
|
files = []
|
||||||
for parameter in parameter_rules:
|
for parameter in parameter_rules:
|
||||||
if parameter.type == ToolParameter.ToolParameterType.FILE:
|
if parameter.type == ToolParameter.ToolParameterType.SYSTEM_FILES:
|
||||||
file = tool_parameters.get(parameter.name)
|
file = tool_parameters.get(parameter.name)
|
||||||
if file:
|
if file:
|
||||||
try:
|
try:
|
||||||
file_var_list = [FileVar(**f) for f in file]
|
file_var_list = [File.model_validate(f) for f in file]
|
||||||
for file_var in file_var_list:
|
for file in file_var_list:
|
||||||
file_dict = {
|
file_dict: dict[str, str | None] = {
|
||||||
"transfer_method": file_var.transfer_method.value,
|
"transfer_method": file.transfer_method.value,
|
||||||
"type": file_var.type.value,
|
"type": file.type.value,
|
||||||
}
|
}
|
||||||
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
|
if file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||||
file_dict["tool_file_id"] = file_var.related_id
|
file_dict["tool_file_id"] = file.related_id
|
||||||
elif file_var.transfer_method == FileTransferMethod.LOCAL_FILE:
|
elif file.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||||
file_dict["upload_file_id"] = file_var.related_id
|
file_dict["upload_file_id"] = file.related_id
|
||||||
elif file_var.transfer_method == FileTransferMethod.REMOTE_URL:
|
elif file.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||||
file_dict["url"] = file_var.preview_url
|
file_dict["url"] = file.generate_url()
|
||||||
|
|
||||||
files.append(file_dict)
|
files.append(file_dict)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -176,7 +178,7 @@ class WorkflowTool(Tool):
|
||||||
|
|
||||||
return parameters_result, files
|
return parameters_result, files
|
||||||
|
|
||||||
def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]:
|
def _extract_files(self, outputs: dict) -> tuple[dict, list[File]]:
|
||||||
"""
|
"""
|
||||||
extract files from the result
|
extract files from the result
|
||||||
|
|
||||||
|
|
@ -187,17 +189,13 @@ class WorkflowTool(Tool):
|
||||||
result = {}
|
result = {}
|
||||||
for key, value in outputs.items():
|
for key, value in outputs.items():
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
has_file = False
|
|
||||||
for item in value:
|
for item in value:
|
||||||
if isinstance(item, dict) and item.get("__variant") == "FileVar":
|
if isinstance(item, dict) and item.get("model_identity") == FILE_MODEL_IDENTITY:
|
||||||
try:
|
file = File.model_validate(item)
|
||||||
files.append(FileVar(**item))
|
files.append(file)
|
||||||
has_file = True
|
elif isinstance(value, dict) and value.get("model_identity") == FILE_MODEL_IDENTITY:
|
||||||
except Exception as e:
|
file = File.model_validate(value)
|
||||||
pass
|
files.append(file)
|
||||||
if has_file:
|
|
||||||
continue
|
|
||||||
|
|
||||||
result[key] = value
|
result[key] = value
|
||||||
|
|
||||||
return result, files
|
return result, files
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from yarl import URL
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||||
from core.file.file_obj import FileTransferMethod
|
from core.file.models import FileTransferMethod
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolInvokeMeta, ToolParameter
|
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolInvokeMeta, ToolParameter
|
||||||
from core.tools.errors import (
|
from core.tools.errors import (
|
||||||
|
|
@ -128,6 +128,7 @@ class ToolEngine:
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# hit the callback handler
|
# hit the callback handler
|
||||||
|
assert tool.identity is not None
|
||||||
workflow_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters)
|
workflow_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters)
|
||||||
|
|
||||||
if isinstance(tool, WorkflowTool):
|
if isinstance(tool, WorkflowTool):
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,12 @@ class ToolFileManager:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_file_by_raw(
|
def create_file_by_raw(
|
||||||
user_id: str, tenant_id: str, conversation_id: Optional[str], file_binary: bytes, mimetype: str
|
*,
|
||||||
|
user_id: str,
|
||||||
|
tenant_id: str,
|
||||||
|
conversation_id: Optional[str],
|
||||||
|
file_binary: bytes,
|
||||||
|
mimetype: str,
|
||||||
) -> ToolFile:
|
) -> ToolFile:
|
||||||
"""
|
"""
|
||||||
create file
|
create file
|
||||||
|
|
@ -68,7 +73,11 @@ class ToolFileManager:
|
||||||
storage.save(filename, file_binary)
|
storage.save(filename, file_binary)
|
||||||
|
|
||||||
tool_file = ToolFile(
|
tool_file = ToolFile(
|
||||||
user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=filename, mimetype=mimetype
|
user_id=user_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
file_key=filename,
|
||||||
|
mimetype=mimetype,
|
||||||
)
|
)
|
||||||
|
|
||||||
db.session.add(tool_file)
|
db.session.add(tool_file)
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,6 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||||
from core.tools.tool.tool import Tool
|
from core.tools.tool.tool import Tool
|
||||||
from core.tools.tool_label_manager import ToolLabelManager
|
from core.tools.tool_label_manager import ToolLabelManager
|
||||||
from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
|
from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
|
||||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||||
from services.tools.tools_transform_service import ToolTransformService
|
from services.tools.tools_transform_service import ToolTransformService
|
||||||
|
|
@ -203,7 +202,7 @@ class ToolManager:
|
||||||
raise ToolProviderNotFoundError(f"provider type {provider_type} not found")
|
raise ToolProviderNotFoundError(f"provider type {provider_type} not found")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]:
|
def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict):
|
||||||
"""
|
"""
|
||||||
init runtime parameter
|
init runtime parameter
|
||||||
"""
|
"""
|
||||||
|
|
@ -222,7 +221,7 @@ class ToolManager:
|
||||||
f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}"
|
f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return ToolParameterConverter.cast_parameter_by_type(parameter_value, parameter_rule.type)
|
return parameter_rule.type.cast_value(parameter_value)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_agent_tool_runtime(
|
def get_agent_tool_runtime(
|
||||||
|
|
@ -243,7 +242,11 @@ class ToolManager:
|
||||||
parameters = tool_entity.get_all_runtime_parameters()
|
parameters = tool_entity.get_all_runtime_parameters()
|
||||||
for parameter in parameters:
|
for parameter in parameters:
|
||||||
# check file types
|
# check file types
|
||||||
if parameter.type == ToolParameter.ToolParameterType.FILE:
|
if parameter.type in {
|
||||||
|
ToolParameter.ToolParameterType.SYSTEM_FILES,
|
||||||
|
ToolParameter.ToolParameterType.FILE,
|
||||||
|
ToolParameter.ToolParameterType.FILES,
|
||||||
|
}:
|
||||||
raise ValueError(f"file type parameter {parameter.name} not supported in agent")
|
raise ValueError(f"file type parameter {parameter.name} not supported in agent")
|
||||||
|
|
||||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
import logging
|
import logging
|
||||||
from mimetypes import guess_extension
|
from mimetypes import guess_extension
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from core.file.file_obj import FileTransferMethod, FileType
|
from core.file import File, FileTransferMethod, FileType
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
from core.tools.tool_file_manager import ToolFileManager
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
|
|
||||||
|
|
@ -11,7 +12,7 @@ logger = logging.getLogger(__name__)
|
||||||
class ToolFileMessageTransformer:
|
class ToolFileMessageTransformer:
|
||||||
@classmethod
|
@classmethod
|
||||||
def transform_tool_invoke_messages(
|
def transform_tool_invoke_messages(
|
||||||
cls, messages: list[ToolInvokeMessage], user_id: str, tenant_id: str, conversation_id: str
|
cls, messages: list[ToolInvokeMessage], user_id: str, tenant_id: str, conversation_id: str | None
|
||||||
) -> list[ToolInvokeMessage]:
|
) -> list[ToolInvokeMessage]:
|
||||||
"""
|
"""
|
||||||
Transform tool message and handle file download
|
Transform tool message and handle file download
|
||||||
|
|
@ -21,7 +22,11 @@ class ToolFileMessageTransformer:
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}:
|
if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}:
|
||||||
result.append(message)
|
result.append(message)
|
||||||
elif message.type == ToolInvokeMessage.MessageType.IMAGE:
|
elif (
|
||||||
|
message.type == ToolInvokeMessage.MessageType.IMAGE
|
||||||
|
and isinstance(message.message, str)
|
||||||
|
and conversation_id
|
||||||
|
):
|
||||||
# try to download image
|
# try to download image
|
||||||
try:
|
try:
|
||||||
file = ToolFileManager.create_file_by_url(
|
file = ToolFileManager.create_file_by_url(
|
||||||
|
|
@ -50,11 +55,14 @@ class ToolFileMessageTransformer:
|
||||||
)
|
)
|
||||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||||
# get mime type and save blob to storage
|
# get mime type and save blob to storage
|
||||||
|
assert message.meta is not None
|
||||||
mimetype = message.meta.get("mime_type", "octet/stream")
|
mimetype = message.meta.get("mime_type", "octet/stream")
|
||||||
# if message is str, encode it to bytes
|
# if message is str, encode it to bytes
|
||||||
if isinstance(message.message, str):
|
if isinstance(message.message, str):
|
||||||
message.message = message.message.encode("utf-8")
|
message.message = message.message.encode("utf-8")
|
||||||
|
|
||||||
|
# FIXME: should do a type check here.
|
||||||
|
assert isinstance(message.message, bytes)
|
||||||
file = ToolFileManager.create_file_by_raw(
|
file = ToolFileManager.create_file_by_raw(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
|
|
@ -63,7 +71,7 @@ class ToolFileMessageTransformer:
|
||||||
mimetype=mimetype,
|
mimetype=mimetype,
|
||||||
)
|
)
|
||||||
|
|
||||||
url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype))
|
url = cls.get_tool_file_url(tool_file_id=file.id, extension=guess_extension(file.mimetype))
|
||||||
|
|
||||||
# check if file is image
|
# check if file is image
|
||||||
if "image" in mimetype:
|
if "image" in mimetype:
|
||||||
|
|
@ -84,12 +92,14 @@ class ToolFileMessageTransformer:
|
||||||
meta=message.meta.copy() if message.meta is not None else {},
|
meta=message.meta.copy() if message.meta is not None else {},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
|
elif message.type == ToolInvokeMessage.MessageType.FILE:
|
||||||
file_var = message.meta.get("file_var")
|
assert message.meta is not None
|
||||||
if file_var:
|
file = message.meta.get("file")
|
||||||
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
|
if isinstance(file, File):
|
||||||
url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
|
if file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||||
if file_var.type == FileType.IMAGE:
|
assert file.related_id is not None
|
||||||
|
url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension)
|
||||||
|
if file.type == FileType.IMAGE:
|
||||||
result.append(
|
result.append(
|
||||||
ToolInvokeMessage(
|
ToolInvokeMessage(
|
||||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||||
|
|
@ -107,11 +117,13 @@ class ToolFileMessageTransformer:
|
||||||
meta=message.meta.copy() if message.meta is not None else {},
|
meta=message.meta.copy() if message.meta is not None else {},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
result.append(message)
|
||||||
else:
|
else:
|
||||||
result.append(message)
|
result.append(message)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str:
|
def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str:
|
||||||
return f'/files/tools/{tool_file_id}{extension or ".bin"}'
|
return f'/files/tools/{tool_file_id}{extension or ".bin"}'
|
||||||
|
|
|
||||||
|
|
@ -1,71 +0,0 @@
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from core.tools.entities.tool_entities import ToolParameter
|
|
||||||
|
|
||||||
|
|
||||||
class ToolParameterConverter:
|
|
||||||
@staticmethod
|
|
||||||
def get_parameter_type(parameter_type: str | ToolParameter.ToolParameterType) -> str:
|
|
||||||
match parameter_type:
|
|
||||||
case (
|
|
||||||
ToolParameter.ToolParameterType.STRING
|
|
||||||
| ToolParameter.ToolParameterType.SECRET_INPUT
|
|
||||||
| ToolParameter.ToolParameterType.SELECT
|
|
||||||
):
|
|
||||||
return "string"
|
|
||||||
|
|
||||||
case ToolParameter.ToolParameterType.BOOLEAN:
|
|
||||||
return "boolean"
|
|
||||||
|
|
||||||
case ToolParameter.ToolParameterType.NUMBER:
|
|
||||||
return "number"
|
|
||||||
|
|
||||||
case _:
|
|
||||||
raise ValueError(f"Unsupported parameter type {parameter_type}")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def cast_parameter_by_type(value: Any, parameter_type: str) -> Any:
|
|
||||||
# convert tool parameter config to correct type
|
|
||||||
try:
|
|
||||||
match parameter_type:
|
|
||||||
case (
|
|
||||||
ToolParameter.ToolParameterType.STRING
|
|
||||||
| ToolParameter.ToolParameterType.SECRET_INPUT
|
|
||||||
| ToolParameter.ToolParameterType.SELECT
|
|
||||||
):
|
|
||||||
if value is None:
|
|
||||||
return ""
|
|
||||||
else:
|
|
||||||
return value if isinstance(value, str) else str(value)
|
|
||||||
|
|
||||||
case ToolParameter.ToolParameterType.BOOLEAN:
|
|
||||||
if value is None:
|
|
||||||
return False
|
|
||||||
elif isinstance(value, str):
|
|
||||||
# Allowed YAML boolean value strings: https://yaml.org/type/bool.html
|
|
||||||
# and also '0' for False and '1' for True
|
|
||||||
match value.lower():
|
|
||||||
case "true" | "yes" | "y" | "1":
|
|
||||||
return True
|
|
||||||
case "false" | "no" | "n" | "0":
|
|
||||||
return False
|
|
||||||
case _:
|
|
||||||
return bool(value)
|
|
||||||
else:
|
|
||||||
return value if isinstance(value, bool) else bool(value)
|
|
||||||
|
|
||||||
case ToolParameter.ToolParameterType.NUMBER:
|
|
||||||
if isinstance(value, int) | isinstance(value, float):
|
|
||||||
return value
|
|
||||||
elif isinstance(value, str) and value != "":
|
|
||||||
if "." in value:
|
|
||||||
return float(value)
|
|
||||||
else:
|
|
||||||
return int(value)
|
|
||||||
case ToolParameter.ToolParameterType.FILE:
|
|
||||||
return value
|
|
||||||
case _:
|
|
||||||
return str(value)
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
raise ValueError(f"The tool parameter value {value} is not in correct type of {parameter_type}.")
|
|
||||||
|
|
@ -1,19 +1,18 @@
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from core.app.app_config.entities import VariableEntity
|
from core.app.app_config.entities import VariableEntity
|
||||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||||
|
|
||||||
|
|
||||||
class WorkflowToolConfigurationUtils:
|
class WorkflowToolConfigurationUtils:
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_parameter_configurations(cls, configurations: list[dict]):
|
def check_parameter_configurations(cls, configurations: Mapping[str, Any]):
|
||||||
"""
|
|
||||||
check parameter configurations
|
|
||||||
"""
|
|
||||||
for configuration in configurations:
|
for configuration in configurations:
|
||||||
if not WorkflowToolParameterConfiguration(**configuration):
|
WorkflowToolParameterConfiguration.model_validate(configuration)
|
||||||
raise ValueError("invalid parameter configuration")
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_workflow_graph_variables(cls, graph: dict) -> list[VariableEntity]:
|
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
|
||||||
"""
|
"""
|
||||||
get workflow graph variables
|
get workflow graph variables
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import logging
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
@ -17,15 +18,18 @@ def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any
|
||||||
:param default_value: the value returned when errors ignored
|
:param default_value: the value returned when errors ignored
|
||||||
:return: an object of the YAML content
|
:return: an object of the YAML content
|
||||||
"""
|
"""
|
||||||
try:
|
if not file_path or not Path(file_path).exists():
|
||||||
with open(file_path, encoding="utf-8") as yaml_file:
|
|
||||||
try:
|
|
||||||
yaml_content = yaml.safe_load(yaml_file)
|
|
||||||
return yaml_content or default_value
|
|
||||||
except Exception as e:
|
|
||||||
raise YAMLError(f"Failed to load YAML file {file_path}: {e}")
|
|
||||||
except Exception as e:
|
|
||||||
if ignore_error:
|
if ignore_error:
|
||||||
return default_value
|
return default_value
|
||||||
else:
|
else:
|
||||||
raise e
|
raise FileNotFoundError(f"File not found: {file_path}")
|
||||||
|
|
||||||
|
with open(file_path, encoding="utf-8") as yaml_file:
|
||||||
|
try:
|
||||||
|
yaml_content = yaml.safe_load(yaml_file)
|
||||||
|
return yaml_content or default_value
|
||||||
|
except Exception as e:
|
||||||
|
if ignore_error:
|
||||||
|
return default_value
|
||||||
|
else:
|
||||||
|
raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,12 @@
|
||||||
from .segment_group import SegmentGroup
|
from .segment_group import SegmentGroup
|
||||||
from .segments import (
|
from .segments import (
|
||||||
ArrayAnySegment,
|
ArrayAnySegment,
|
||||||
|
ArrayFileSegment,
|
||||||
|
ArrayNumberSegment,
|
||||||
|
ArrayObjectSegment,
|
||||||
ArraySegment,
|
ArraySegment,
|
||||||
|
ArrayStringSegment,
|
||||||
|
FileSegment,
|
||||||
FloatSegment,
|
FloatSegment,
|
||||||
IntegerSegment,
|
IntegerSegment,
|
||||||
NoneSegment,
|
NoneSegment,
|
||||||
|
|
@ -15,6 +20,7 @@ from .variables import (
|
||||||
ArrayNumberVariable,
|
ArrayNumberVariable,
|
||||||
ArrayObjectVariable,
|
ArrayObjectVariable,
|
||||||
ArrayStringVariable,
|
ArrayStringVariable,
|
||||||
|
FileVariable,
|
||||||
FloatVariable,
|
FloatVariable,
|
||||||
IntegerVariable,
|
IntegerVariable,
|
||||||
NoneVariable,
|
NoneVariable,
|
||||||
|
|
@ -46,4 +52,10 @@ __all__ = [
|
||||||
"ArrayNumberVariable",
|
"ArrayNumberVariable",
|
||||||
"ArrayObjectVariable",
|
"ArrayObjectVariable",
|
||||||
"ArraySegment",
|
"ArraySegment",
|
||||||
|
"ArrayFileSegment",
|
||||||
|
"ArrayNumberSegment",
|
||||||
|
"ArrayObjectSegment",
|
||||||
|
"ArrayStringSegment",
|
||||||
|
"FileSegment",
|
||||||
|
"FileVariable",
|
||||||
]
|
]
|
||||||
|
|
@ -5,6 +5,8 @@ from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, field_validator
|
from pydantic import BaseModel, ConfigDict, field_validator
|
||||||
|
|
||||||
|
from core.file import File
|
||||||
|
|
||||||
from .types import SegmentType
|
from .types import SegmentType
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -39,6 +41,9 @@ class Segment(BaseModel):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def size(self) -> int:
|
def size(self) -> int:
|
||||||
|
"""
|
||||||
|
Return the size of the value in bytes.
|
||||||
|
"""
|
||||||
return sys.getsizeof(self.value)
|
return sys.getsizeof(self.value)
|
||||||
|
|
||||||
def to_object(self) -> Any:
|
def to_object(self) -> Any:
|
||||||
|
|
@ -99,13 +104,27 @@ class ArraySegment(Segment):
|
||||||
def markdown(self) -> str:
|
def markdown(self) -> str:
|
||||||
items = []
|
items = []
|
||||||
for item in self.value:
|
for item in self.value:
|
||||||
if hasattr(item, "to_markdown"):
|
items.append(str(item))
|
||||||
items.append(item.to_markdown())
|
|
||||||
else:
|
|
||||||
items.append(str(item))
|
|
||||||
return "\n".join(items)
|
return "\n".join(items)
|
||||||
|
|
||||||
|
|
||||||
|
class FileSegment(Segment):
|
||||||
|
value_type: SegmentType = SegmentType.FILE
|
||||||
|
value: File
|
||||||
|
|
||||||
|
@property
|
||||||
|
def markdown(self) -> str:
|
||||||
|
return self.value.markdown
|
||||||
|
|
||||||
|
@property
|
||||||
|
def log(self) -> str:
|
||||||
|
return str(self.value)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def text(self) -> str:
|
||||||
|
return str(self.value)
|
||||||
|
|
||||||
|
|
||||||
class ArrayAnySegment(ArraySegment):
|
class ArrayAnySegment(ArraySegment):
|
||||||
value_type: SegmentType = SegmentType.ARRAY_ANY
|
value_type: SegmentType = SegmentType.ARRAY_ANY
|
||||||
value: Sequence[Any]
|
value: Sequence[Any]
|
||||||
|
|
@ -124,3 +143,15 @@ class ArrayNumberSegment(ArraySegment):
|
||||||
class ArrayObjectSegment(ArraySegment):
|
class ArrayObjectSegment(ArraySegment):
|
||||||
value_type: SegmentType = SegmentType.ARRAY_OBJECT
|
value_type: SegmentType = SegmentType.ARRAY_OBJECT
|
||||||
value: Sequence[Mapping[str, Any]]
|
value: Sequence[Mapping[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
class ArrayFileSegment(ArraySegment):
|
||||||
|
value_type: SegmentType = SegmentType.ARRAY_FILE
|
||||||
|
value: Sequence[File]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def markdown(self) -> str:
|
||||||
|
items = []
|
||||||
|
for item in self.value:
|
||||||
|
items.append(item.markdown)
|
||||||
|
return "\n".join(items)
|
||||||
|
|
@ -11,5 +11,7 @@ class SegmentType(str, Enum):
|
||||||
ARRAY_NUMBER = "array[number]"
|
ARRAY_NUMBER = "array[number]"
|
||||||
ARRAY_OBJECT = "array[object]"
|
ARRAY_OBJECT = "array[object]"
|
||||||
OBJECT = "object"
|
OBJECT = "object"
|
||||||
|
FILE = "file"
|
||||||
|
ARRAY_FILE = "array[file]"
|
||||||
|
|
||||||
GROUP = "group"
|
GROUP = "group"
|
||||||
|
|
@ -7,6 +7,7 @@ from .segments import (
|
||||||
ArrayNumberSegment,
|
ArrayNumberSegment,
|
||||||
ArrayObjectSegment,
|
ArrayObjectSegment,
|
||||||
ArrayStringSegment,
|
ArrayStringSegment,
|
||||||
|
FileSegment,
|
||||||
FloatSegment,
|
FloatSegment,
|
||||||
IntegerSegment,
|
IntegerSegment,
|
||||||
NoneSegment,
|
NoneSegment,
|
||||||
|
|
@ -73,3 +74,7 @@ class SecretVariable(StringVariable):
|
||||||
class NoneVariable(NoneSegment, Variable):
|
class NoneVariable(NoneSegment, Variable):
|
||||||
value_type: SegmentType = SegmentType.NONE
|
value_type: SegmentType = SegmentType.NONE
|
||||||
value: None = None
|
value: None = None
|
||||||
|
|
||||||
|
|
||||||
|
class FileVariable(FileSegment, Variable):
|
||||||
|
pass
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
from .base_workflow_callback import WorkflowCallback
|
||||||
|
from .workflow_logging_callback import WorkflowLoggingCallback
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"WorkflowLoggingCallback",
|
||||||
|
"WorkflowCallback",
|
||||||
|
]
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
|
||||||
from core.workflow.graph_engine.entities.event import (
|
from core.workflow.graph_engine.entities.event import (
|
||||||
GraphEngineEvent,
|
GraphEngineEvent,
|
||||||
GraphRunFailedEvent,
|
GraphRunFailedEvent,
|
||||||
|
|
@ -20,6 +19,8 @@ from core.workflow.graph_engine.entities.event import (
|
||||||
ParallelBranchRunSucceededEvent,
|
ParallelBranchRunSucceededEvent,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .base_workflow_callback import WorkflowCallback
|
||||||
|
|
||||||
_TEXT_COLOR_MAPPING = {
|
_TEXT_COLOR_MAPPING = {
|
||||||
"blue": "36;1",
|
"blue": "36;1",
|
||||||
"yellow": "33;1",
|
"yellow": "33;1",
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
SYSTEM_VARIABLE_NODE_ID = "sys"
|
||||||
|
ENVIRONMENT_VARIABLE_NODE_ID = "env"
|
||||||
|
CONVERSATION_VARIABLE_NODE_ID = "conversation"
|
||||||
|
|
@ -1,52 +1,14 @@
|
||||||
|
from collections.abc import Mapping
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
from models import WorkflowNodeExecutionStatus
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
|
|
||||||
class NodeType(Enum):
|
class NodeRunMetadataKey(str, Enum):
|
||||||
"""
|
|
||||||
Node Types.
|
|
||||||
"""
|
|
||||||
|
|
||||||
START = "start"
|
|
||||||
END = "end"
|
|
||||||
ANSWER = "answer"
|
|
||||||
LLM = "llm"
|
|
||||||
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
|
|
||||||
IF_ELSE = "if-else"
|
|
||||||
CODE = "code"
|
|
||||||
TEMPLATE_TRANSFORM = "template-transform"
|
|
||||||
QUESTION_CLASSIFIER = "question-classifier"
|
|
||||||
HTTP_REQUEST = "http-request"
|
|
||||||
TOOL = "tool"
|
|
||||||
VARIABLE_AGGREGATOR = "variable-aggregator"
|
|
||||||
# TODO: merge this into VARIABLE_AGGREGATOR
|
|
||||||
VARIABLE_ASSIGNER = "variable-assigner"
|
|
||||||
LOOP = "loop"
|
|
||||||
ITERATION = "iteration"
|
|
||||||
ITERATION_START = "iteration-start" # fake start node for iteration
|
|
||||||
PARAMETER_EXTRACTOR = "parameter-extractor"
|
|
||||||
CONVERSATION_VARIABLE_ASSIGNER = "assigner"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def value_of(cls, value: str) -> "NodeType":
|
|
||||||
"""
|
|
||||||
Get value of given node type.
|
|
||||||
|
|
||||||
:param value: node type value
|
|
||||||
:return: node type
|
|
||||||
"""
|
|
||||||
for node_type in cls:
|
|
||||||
if node_type.value == value:
|
|
||||||
return node_type
|
|
||||||
raise ValueError(f"invalid node type value {value}")
|
|
||||||
|
|
||||||
|
|
||||||
class NodeRunMetadataKey(Enum):
|
|
||||||
"""
|
"""
|
||||||
Node Run Metadata Key.
|
Node Run Metadata Key.
|
||||||
"""
|
"""
|
||||||
|
|
@ -70,7 +32,7 @@ class NodeRunResult(BaseModel):
|
||||||
|
|
||||||
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
|
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
|
||||||
|
|
||||||
inputs: Optional[dict[str, Any]] = None # node inputs
|
inputs: Optional[Mapping[str, Any]] = None # node inputs
|
||||||
process_data: Optional[dict[str, Any]] = None # process data
|
process_data: Optional[dict[str, Any]] = None # process data
|
||||||
outputs: Optional[dict[str, Any]] = None # node outputs
|
outputs: Optional[dict[str, Any]] = None # node outputs
|
||||||
metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata
|
metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata
|
||||||
|
|
@ -79,24 +41,3 @@ class NodeRunResult(BaseModel):
|
||||||
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
|
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
|
||||||
|
|
||||||
error: Optional[str] = None # error message if status is failed
|
error: Optional[str] = None # error message if status is failed
|
||||||
|
|
||||||
|
|
||||||
class UserFrom(Enum):
|
|
||||||
"""
|
|
||||||
User from
|
|
||||||
"""
|
|
||||||
|
|
||||||
ACCOUNT = "account"
|
|
||||||
END_USER = "end-user"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def value_of(cls, value: str) -> "UserFrom":
|
|
||||||
"""
|
|
||||||
Value of
|
|
||||||
:param value: value
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
for item in cls:
|
|
||||||
if item.value == value:
|
|
||||||
return item
|
|
||||||
raise ValueError(f"Invalid value: {value}")
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -7,4 +9,4 @@ class VariableSelector(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
variable: str
|
variable: str
|
||||||
value_selector: list[str]
|
value_selector: Sequence[str]
|
||||||
|
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue