Compare commits

...

372 Commits

Author SHA1 Message Date
Yeuoly
28edbbac0b
Plugins/bump to 1.0.0 beta.1 (#12568) 2025-01-09 22:46:24 +08:00
Yeuoly
782abcecd8
bump version to 1.0.0-beta.1 (#12567) 2025-01-09 22:38:20 +08:00
Yeuoly
4deb02fc2c
fix: rename plugin db name to dify_plugin (#12565) 2025-01-09 21:56:24 +08:00
Joel
f967180dc2
fix: not show stragry type (#12561) 2025-01-09 20:55:17 +08:00
Yeuoly
cead13cbc3
plugins: remove middleware.1.yaml (#12559) 2025-01-09 20:34:49 +08:00
Yeuoly
078c151065
fix: add-default-console-url (#12558) 2025-01-09 20:34:13 +08:00
Yeuoly
17babca362
Introducing: Plugin Mechanism (#12553) 2025-01-09 19:54:17 +08:00
AkaraChen
8efed8858c
feat: reset parameters when switch agent strategy (#12549) 2025-01-09 19:31:02 +08:00
Yeuoly
0d411a0b5a
feat: refactor docker-compose (#12550) 2025-01-09 19:08:11 +08:00
Yeuoly
13f0c01f93
feat: add ci checks to plugins/beta branch (#12542)
Co-authored-by: Novice Lee <novicelee@NoviPro.local>
2025-01-09 18:57:09 +08:00
zxhlyh
3c014f3ae5
Feat/plugins (#12547)
Co-authored-by: AkaraChen <akarachen@outlook.com>
Co-authored-by: Yi <yxiaoisme@gmail.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: kurokobo <kuro664@gmail.com>
Co-authored-by: Hiroshi Fujita <fujita-h@users.noreply.github.com>
2025-01-09 18:47:41 +08:00
Yeuoly
e4c4490175 refactor 2025-01-09 17:27:05 +08:00
Yeuoly
94a62f6b4e enhancement: remove unrequired deps 2025-01-09 17:06:38 +08:00
Novice Lee
d76af08784 feat: add agent log icon 2025-01-09 16:55:17 +08:00
Yeuoly
f748d6c7c4 fix: mypy issues 2025-01-09 16:53:30 +08:00
Yeuoly
76e24d91c0 fix: migrations 2025-01-09 13:30:43 +08:00
Novice Lee
5ce4ddc0ed fix: change the agent strategy category 2025-01-09 11:13:00 +08:00
Novice Lee
491d641485 feat: add agent node log 2025-01-09 08:32:32 +08:00
Yeuoly
172c5f19cc fix: formatter 2025-01-08 21:11:58 +08:00
Yeuoly
b7d168ac59 fix: mypy linter 2025-01-08 21:11:42 +08:00
Yeuoly
fb309462ad Merge branch 'main' into fix/chore-fix 2025-01-08 20:36:22 +08:00
Novice Lee
b56d2b739b feat: add fc agent mode support 2025-01-08 07:41:17 +08:00
Yeuoly
fb7b2c8ff3 fix: backwards invoke nodes 2025-01-07 20:52:25 +08:00
Yeuoly
c3440a27fb fix 2025-01-07 18:59:13 +08:00
Yeuoly
ff3d3f71fb fix: use host.docker.internal as the default plugin daemon middleware endpoint 2025-01-07 14:56:03 +08:00
Yeuoly
9685b9a302 refactor: docker-compose-middleware.yaml 2025-01-07 14:44:08 +08:00
Yeuoly
07c7b7b886 fix: remove 5002 port from docker mapping 2025-01-06 21:45:44 +08:00
kurokobo
8d75abc976
fix: correct fetch_from for customizable models (#12400) 2025-01-06 21:16:39 +08:00
Yeuoly
aa6452b3bf fix: use session to manage AppSite 2025-01-06 21:12:50 +08:00
Yeuoly
3799d40937 feat: support docker deployment for plugin 2025-01-06 20:28:50 +08:00
Yeuoly
d2ff8a2381 fix: bugs 2025-01-06 14:59:40 +08:00
Yeuoly
5f51a19de2 fix: allow meta to be None 2025-01-03 14:48:19 +08:00
Yeuoly
71e0bfcbd8 fix: updating tool credentials does not works as expected 2025-01-03 14:09:17 +08:00
Yeuoly
d815c74fc5 fix: ruff 2024-12-31 16:48:20 +08:00
Yeuoly
107e44c8fb Merge branch 'main' into fix/chore-fix 2024-12-31 16:47:56 +08:00
Yeuoly
adf7eea7fe fix: ruff 2024-12-31 16:40:26 +08:00
Yeuoly
6e73ad2fc6 feat: plugin migrations 2024-12-31 16:38:02 +08:00
Yeuoly
06412b37d3 fix: no attribbute identity 2024-12-30 21:14:24 +08:00
Yeuoly
63665a5ff1 feat: add conversation_id to invoke 2024-12-30 13:41:54 +08:00
Yeuoly
05a43e3e80 fix: rebaseing to main 2024-12-30 13:34:45 +08:00
Yeuoly
83fdb42520 fix: variable message 2024-12-27 22:56:39 +08:00
Yeuoly
cbf405beea fix: remigrate 2024-12-27 18:37:34 +08:00
Yeuoly
af2aede783 feat: support precision to PluginParameter 2024-12-27 18:07:28 +08:00
Yeuoly
e359ace633 fix: add agent logs 2024-12-27 17:55:41 +08:00
Yeuoly
a5555f90c6 fix: models 2024-12-27 17:34:17 +08:00
Yeuoly
78664c8903 Merge branch 'main' into fix/chore-fix 2024-12-27 17:33:58 +08:00
Yeuoly
45070535bd fix: linter 2024-12-27 14:47:48 +08:00
Yeuoly
048e8cf0d1 fix: remove validate credentials 2024-12-27 12:16:58 +08:00
Yeuoly
598d208e54 fix: agent error handling 2024-12-27 12:09:39 +08:00
Yeuoly
8102cee8df fix: unbound reference 2024-12-27 11:33:04 +08:00
Yeuoly
c9eb9c14d7 fix: block call to flask_app 2024-12-26 22:58:34 +08:00
Yeuoly
e77cd87842 fix: linter 2024-12-26 22:30:22 +08:00
Yeuoly
ac5e3caebc optimize: migrate speed 2024-12-26 22:30:06 +08:00
Yeuoly
23066a9ba8 feat: support extracting plugins into local files 2024-12-26 18:05:14 +08:00
Yeuoly
0249f15609 fix: linter 2024-12-26 17:39:21 +08:00
Yeuoly
2f523dd29f optimize: add friendly logs 2024-12-26 17:39:13 +08:00
Yeuoly
b34d815883 feat: support auto generate and template 2024-12-26 17:25:56 +08:00
Yeuoly
51cc63d9ce fix: undefined dereference to ApiTool 2024-12-26 14:12:43 +08:00
Yeuoly
430af95b53 fix: linter 2024-12-26 14:07:29 +08:00
Yeuoly
0164d1410a migrations for plugins 2024-12-26 14:07:12 +08:00
Yeuoly
cbc5045b7a fix: ruff formatter 2024-12-26 13:23:56 +08:00
Yeuoly
b980c07af8 fix: ruff formatter 2024-12-26 13:22:18 +08:00
Yeuoly
e231cf2c48 fix: errors occrus during rebasing 2024-12-26 13:20:12 +08:00
Yeuoly
80d8e47e42 fix: skip json transforming if error occurs 2024-12-25 18:23:31 +08:00
Yeuoly
fee4dd7d7a fix: unused stream variable 2024-12-25 15:32:59 +08:00
Yeuoly
00cf5f3841 fix: linters 2024-12-25 15:18:29 +08:00
Yeuoly
9ee0c7a694 merge 2024-12-25 14:39:15 +08:00
Yeuoly
6ee7ca1890 fix: add specific exceptions 2024-12-24 22:00:45 +08:00
Yeuoly
f589397f25 fix: import Optional 2024-12-24 21:56:55 +08:00
Yeuoly
ee080dddf9 fix: rebase 2024-12-24 21:48:49 +08:00
Yeuoly
ee6841648c fix: migrations and imports recycle 2024-12-24 21:36:42 +08:00
Yeuoly
5a57dad93c fix: linter 2024-12-24 21:29:24 +08:00
Yeuoly
4199998c7e Merge branch 'main' into fix/chore-fix 2024-12-24 21:28:56 +08:00
Yeuoly
39656f7f84 fix: linter and formatter 2024-12-24 18:38:34 +08:00
Yeuoly
bf39e314d8 fix: add install count 2024-12-24 18:38:12 +08:00
Yeuoly
8cc4c109d0
fix: return types of builtin tools 2024-12-19 01:09:15 +08:00
Yeuoly
a1cdca02e3
fix: formatter 2024-12-19 01:02:44 +08:00
Yeuoly
1b21d7513d
fix: reduce model provider fetchs 2024-12-19 01:02:08 +08:00
takatost
d5c708c62b feat: add plugin_model_providers context 2024-12-19 00:50:46 +08:00
Yeuoly
342d4060ff
fix: add additional parameters to exists tools 2024-12-18 23:54:48 +08:00
Yeuoly
05232d36f0
fix: add default values to WorkflowAppGenerator 2024-12-17 15:49:33 +08:00
Yeuoly
636dde94c7
fix: migrations 2024-12-16 14:17:39 +08:00
Yeuoly
75fe785d88
Merge branch 'main' into fix/chore-fix 2024-12-16 14:08:18 +08:00
Yeuoly
a61da6cf95
fix: replace Enum with StrEnum 2024-12-16 13:40:02 +08:00
Yeuoly
93c3699128
feat: add label to agent log 2024-12-15 18:12:29 +08:00
Yeuoly
6357450a7a
feat: support hidden parameters 2024-12-13 22:53:08 +08:00
Yeuoly
6339706c68
fix: ruff reformatter 2024-12-13 19:51:09 +08:00
Yeuoly
65a4cb769b
refactor: tool entities 2024-12-13 19:50:54 +08:00
Yeuoly
63206a7967
fix: incorrect use of node execution id 2024-12-13 00:05:57 +08:00
Yeuoly
9a6f120e5c
feat: support agent log event 2024-12-12 23:46:26 +08:00
Yeuoly
dedc1b0c3a
refactor: agent strategy parameter 2024-12-12 19:16:06 +08:00
Yeuoly
46bb246ecc
refactor: rename agent to agent strategy 2024-12-12 18:27:43 +08:00
Yeuoly
3c628d0c26
refactor: rename agent to agent strategy 2024-12-12 18:27:31 +08:00
Yeuoly
c2983ecbb7
fix: rename stream to streaming 2024-12-12 13:50:34 +08:00
Yeuoly
527c1cf608
fix: deduplicate provider id 2024-12-10 02:21:46 +08:00
Yeuoly
93786f516c
apply ruff 2024-12-10 00:22:54 +08:00
Yeuoly
a175d6b2d7
feat: agent management 2024-12-10 00:22:41 +08:00
Yeuoly
296fd82bbf
fix: agent node 2024-12-09 23:26:16 +08:00
Yeuoly
4ccd571364
fix: ruff 2024-12-09 23:02:25 +08:00
Yeuoly
ae72514cb4
feat: support agent node 2024-12-09 23:02:11 +08:00
Yeuoly
16b49ac436
Merge branch 'main' into fix/chore-fix 2024-12-09 16:08:19 +08:00
Yeuoly
c377eb8c28
fix: unbound variable in tool node 2024-12-09 15:43:01 +08:00
Yeuoly
337eff2b79
Merge branch 'main' into fix/chore-fix 2024-12-06 16:45:25 +08:00
Yeuoly
b7ac287fec
fix: use default_factory for list fields 2024-12-05 20:57:30 +08:00
Yeuoly
c1a85b0208
fix: add default value to plugin permission field 2024-12-05 14:48:34 +08:00
Yeuoly
01efdee1dd
fix: support other file types for Tool 2024-12-04 19:26:01 +08:00
Yeuoly
0af9c4fd9d
chore: reformat 2024-12-04 19:02:28 +08:00
Yeuoly
ee38bd8817
refactor: check dependencies 2024-12-04 19:01:54 +08:00
Yeuoly
86291c13e4
Merge branch 'main' into fix/chore-fix 2024-12-04 15:34:39 +08:00
Yeuoly
7679a57f18
fix: agent type errors 2024-12-03 19:44:57 +08:00
Yeuoly
dcf19549cb
feat: move audio and webscraper back to dify 2024-12-03 19:27:57 +08:00
Yeuoly
574a6c1ded
fix: add extension, filename and size to PluginFileEntity 2024-12-03 16:51:51 +08:00
Yeuoly
c34877aecf
fix: update tool provider credentials 2024-12-03 16:28:36 +08:00
Yeuoly
632b2bac2a
fix: invoke-email 2024-12-02 21:59:52 +08:00
Yeuoly
77a62f33b3
fix: Lookup errors for contextvars used in ToolManager 2024-12-02 21:25:47 +08:00
Yeuoly
ad899844a1
fix: workflow loads tool provider icon 2024-12-02 21:08:36 +08:00
Yeuoly
b10d6051ba
fix: summary and create_file_by_url 2024-12-02 16:51:37 +08:00
Yeuoly
fb44cd87e7
fix: image url message 2024-11-29 18:20:36 +08:00
Yeuoly
89af726985
fix: cot agent 2024-11-29 16:48:39 +08:00
Yeuoly
6f2d5ff099
fix: add tenant_id to invoke tts 2024-11-29 15:59:07 +08:00
Yeuoly
687455ca31
fix: tool file id 2024-11-29 14:09:34 +08:00
Yeuoly
8c5928da2f
fix: unify error handling 2024-11-28 20:44:06 +08:00
Yeuoly
772009115d
fix: keep process_data with None if not 2024-11-28 19:35:30 +08:00
Yeuoly
0452dfd029
fix: missing tool invoke messages 2024-11-28 19:09:04 +08:00
Yeuoly
eead6abe85
fix: tool image url response 2024-11-28 18:23:28 +08:00
Yeuoly
5c6d919a4a
fix: handle detailed error type 2024-11-28 17:12:29 +08:00
Yeuoly
e39eddab03
fix: change to use convert_stream_full_response 2024-11-27 14:48:44 +08:00
Yeuoly
db726e02a0
feat: support multi token count 2024-11-26 18:59:03 +08:00
Yeuoly
e4b8220bc2
Merge branch 'main' into fix/chore-fix 2024-11-26 18:02:41 +08:00
Yeuoly
08cfcb453c
fix: missing marshal fields of leaked+dependencies 2024-11-26 13:59:52 +08:00
Yeuoly
992e1eedde
fix: export agent dsl 2024-11-25 23:36:19 +08:00
Yeuoly
c2ce8e638e
fix: deleted_tools 2024-11-25 23:22:17 +08:00
Yeuoly
ba3659a792
feat: support delete all install tasks 2024-11-25 17:11:41 +08:00
Yeuoly
965fabd578
fix: rename dependencies 2024-11-25 16:57:38 +08:00
Yeuoly
accbbae755
cleanup: remove get_interates 2024-11-25 16:47:49 +08:00
Yeuoly
49bd1a7a49
fix: riff 2024-11-25 16:44:08 +08:00
Yeuoly
5ff9cee326
Merge branch 'main' into fix/chore-fix 2024-11-25 15:37:19 +08:00
Yeuoly
200f9af5d8
optimize error messages 2024-11-22 20:04:20 +08:00
Yeuoly
1443fd6739
optimize: indexing-estimate 2024-11-22 19:39:07 +08:00
Yeuoly
e63ae36665
fix 2024-11-22 18:19:02 +08:00
Yeuoly
cfa7c89dfe
refactor: text-embedding interfaces to returns list[int] 2024-11-22 18:09:33 +08:00
Yeuoly
a6835ac64d
fix: add detailed error messages 2024-11-21 17:00:00 +08:00
Yeuoly
a700b49461
fix: migration 2024-11-21 13:55:08 +08:00
Yeuoly
22df86fe8a
fix: ruff 2024-11-21 13:53:08 +08:00
Yeuoly
24734009b9
Merge branch 'main' into fix/chore-fix 2024-11-21 13:52:28 +08:00
Yeuoly
959d060a44
fix: remove signature verify 2024-11-21 00:30:28 +08:00
Yeuoly
4492295683
fix: remove plugin files 2024-11-20 18:12:12 +08:00
Yeuoly
88fac0d898
fix: add tenant_id to plugin upload files url 2024-11-19 16:50:14 +08:00
Yeuoly
8b30099672
fix: convert backwards invocation into BaseBackwardsResponse 2024-11-19 14:03:40 +08:00
Yeuoly
97a3727962
fix: optimize DEFAULT-USER 2024-11-18 17:21:17 +08:00
Yeuoly
2cb640de15
refactor: load tools cache 2024-11-15 19:53:50 +08:00
Yeuoly
fb4ee813c7
fix: agent 2024-11-15 18:37:33 +08:00
Yeuoly
6300e506fb
fix: rag 2024-11-15 15:54:14 +08:00
Yeuoly
a0543ab8fb
Merge branch 'main' into fix/chore-fix 2024-11-15 15:43:32 +08:00
Yeuoly
634cb6233e
feat: sypport batch fetch plugin installations 2024-11-15 00:47:25 +08:00
Yeuoly
db68ae4a73
feat: support upload bundle 2024-11-14 22:58:57 +08:00
Yeuoly
d25e79e794
feat: support uploading images through plugin 2024-11-14 18:32:51 +08:00
Yeuoly
183b943803
feat: support check dependencies through url 2024-11-13 15:19:20 +08:00
Yeuoly
5828abcd62
fix: uses to check if the tools are already loaded 2024-11-12 21:43:19 +08:00
Yeuoly
56bd0dedfe
fix: incorrect paths to upgrade plugins 2024-11-12 20:48:28 +08:00
Yeuoly
f6136427a4
feat: export dsl with dependencies 2024-11-12 19:50:56 +08:00
Yeuoly
21fd58caf9
Merge branch 'fix/chore-fix' of github.com:langgenius/dify into fix/chore-fix 2024-11-12 18:53:45 +08:00
Yeuoly
9a69d03fbe
feat: add icon and labels to plugin install task 2024-11-11 20:59:31 +08:00
takatost
1d2118fc5d fix: hosted moderation 2024-11-11 20:31:11 +08:00
takatost
bc0724b499 chore: fix typo 2024-11-11 19:50:39 +08:00
Yeuoly
5cdbfe2f41
Merge branch 'main' into fix/chore-fix 2024-11-11 14:00:53 +08:00
Yeuoly
5fd82084f9
fix: avoid empty plugin entity 2024-11-11 13:30:11 +08:00
takatost
f0637ba332 fix: create basic app causing internal error when default model is not exist 2024-11-08 23:09:52 +08:00
takatost
115c9486c3 fix hosted issues 2024-11-08 19:23:49 +08:00
Yeuoly
8b5231b7ee
fix: invalid key of marketplace response 2024-11-08 17:27:16 +08:00
Yeuoly
38cae29757
fix: wrap marketplace apis with try catch 2024-11-08 17:20:54 +08:00
Yeuoly
7a2b2a04c9
Merge branch 'main' into fix/chore-fix 2024-11-08 13:47:24 +08:00
Yeuoly
fe677cc5f9
Merge branch 'main' into fix/chore-fix 2024-11-07 17:06:29 +08:00
Yeuoly
28c9ec3f4f
feat: support fetch tool provider info 2024-11-06 17:30:50 +08:00
Yeuoly
6baa98f166
feat: support app-selector, model-selector and tool-selector as parameters 2024-11-06 17:13:05 +08:00
Yeuoly
e9d69f020a
feat: cast files into correct type while invoking 2024-11-05 20:30:13 +08:00
Novice
3c89d45a2d
fix: iteration none output error (#10295) 2024-11-05 20:30:13 +08:00
-LAN-
baab81714e
fix(http_request): improve parameter initialization and reorganize tests (#10297) 2024-11-05 20:30:13 +08:00
Matsuda
507bb3549a
fix typo: writeOpner to writeOpener (#10290) 2024-11-05 20:30:13 +08:00
pinsily
2d1e5fb4e0
fix: handle KeyError when accessing rules in CleanProcessor.clean (#10258) 2024-11-05 20:30:12 +08:00
eux
b9198639e2
fix: borken faq url in CONTRIBUTING.md (#10275) 2024-11-05 20:30:12 +08:00
非法操作
43c7739b88
feat: add xAI model provider (#10272) 2024-11-05 20:30:12 +08:00
Matsuda
f65d577f54
fix(model_runtime): fix wrong max_tokens for Claude 3.5 Haiku on Amazon Bedrock (#10286) 2024-11-05 20:30:00 +08:00
-LAN-
b88145096f
feat(model): add validation for custom disclaimer length (#10287) 2024-11-05 20:30:00 +08:00
-LAN-
33219e850a
fix(node): correct file property name in function switch (#10284) 2024-11-05 20:30:00 +08:00
NFish
3040d538f7
refactor the logic of refreshing access_token (#10068) 2024-11-05 20:30:00 +08:00
github-actions[bot]
4e1af81e11
chore: translate i18n files (#10273)
Co-authored-by: laipz8200 <16485841+laipz8200@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-11-05 20:30:00 +08:00
Benjamin
56e19fd8f5
Updates: Add mplfonts library for customizing matplotlib fonts and Va… (#9903) 2024-11-05 20:30:00 +08:00
Novice
d330d31ee5
feat: Iteration node support parallel mode (#9493) 2024-11-05 20:29:59 +08:00
GeorgeCaoJ
0858108423
fix(workflow): handle else condition branch addition error in if-else node (#10257) 2024-11-05 20:29:59 +08:00
-LAN-
2cd976846a
feat(document_extractor): support tool file in document extractor (#10217) 2024-11-05 20:29:59 +08:00
Matsuda
5d2c88ef59
feat: support Claude 3.5 Haiku on Amazon Bedrock (#10265) 2024-11-05 20:29:59 +08:00
-LAN-
fe3cde973e
refactor(parameter_extractor): implement custom error classes (#10260) 2024-11-05 20:29:59 +08:00
-LAN-
794f495ef2
fix(validation): allow to use 0 in the inputs form (#10255) 2024-11-05 20:29:32 +08:00
-LAN-
0dda682033
chore(Dockerfile): upgrade zlib arm64 (#10244) 2024-11-05 20:29:31 +08:00
方程
01d8d10f1c
Using a dedicated interface to obtain the token credential for the gitee.ai provider (#10243) 2024-11-05 20:29:12 +08:00
-LAN-
c711c5e36e
feat(workflow): add configurable workflow file upload limit (#10176)
Co-authored-by: JzoNg <jzongcode@gmail.com>
2024-11-05 20:29:09 +08:00
shisaru292
1e27557865
fix: missing working directory parameter in script (#10226) 2024-11-05 20:28:29 +08:00
-LAN-
2d9632d8b9
refactor(list_operator): replace ValueError with InvalidKeyError (#10222) 2024-11-05 20:28:29 +08:00
-LAN-
7e42de1e7b
refactor(workflow): introduce specific error handling for LLM nodes (#10221) 2024-11-05 20:28:29 +08:00
-LAN-
bd674d27be
refactor(http_request): add custom exception handling for HTTP request nodes (#10219) 2024-11-05 20:28:29 +08:00
-LAN-
5735761920
refactor(workflow): introduce specific exceptions for code validation (#10218) 2024-11-05 20:28:29 +08:00
-LAN-
405b704f02
chore(llm_node): remove unnecessary type ignore for context assignment (#10216) 2024-11-05 20:28:29 +08:00
Jyong
f38abaaa6a
fix the ssrf of docx file extractor external images (#10237) 2024-11-05 20:28:28 +08:00
Hanqing Zhao
c8a5fee622
Modify translation (#10213) 2024-11-05 20:28:28 +08:00
Jiang
fe1c0ac602
Add Lindorm as a VDB choice (#10202)
Co-authored-by: jiangzhijie <jiangzhijie.jzj@alibaba-inc.com>
2024-11-05 20:28:28 +08:00
crazywoola
e79c3e4531
Fix/10199 application error a client side exception has occurred see the browser console for more information (#10211) 2024-11-05 20:28:28 +08:00
-LAN-
3ea3df7189
refactor(validation): improve input validation logic (#10175) 2024-11-05 20:28:28 +08:00
-LAN-
b01e7d778e
chore(list_operator): refine exception handling for error specificity (#10206) 2024-11-05 20:28:28 +08:00
-LAN-
7c45859594
fix(document_extractor): update base exception class (#10208) 2024-11-05 20:28:27 +08:00
Kota-Yamaguchi
aa9fd76072
Feat : add LLM model indicator in prompt generator (#10187) 2024-11-05 20:28:27 +08:00
Kota-Yamaguchi
e7d947379f
chore : code generator preview hint (#10188) 2024-11-05 20:28:17 +08:00
zxhlyh
8cd386f2c1
fix: webapp upload file (#10195) 2024-11-05 20:28:17 +08:00
-LAN-
987e1b9ced
fix(api): replace current_user with end_user in file upload (#10194) 2024-11-05 20:28:17 +08:00
-LAN-
81a77d0623
feat(document_extractor): integrate unstructured API for PPTX extraction (#10180) 2024-11-05 20:28:17 +08:00
Cling_o3
ac1f93e3d5
[fix] fix the bug that modify document name not effective (#10154) 2024-11-05 20:27:58 +08:00
-LAN-
0d5c0b4fe4
fix(workflow model): ensure consistent timestamp updating (#10172) 2024-11-05 20:27:57 +08:00
jiangbo721
d1c480a7d8
fix: Cannot find declaration to go to CLEAN_DAY_SETTING (#10157)
Co-authored-by: 刘江波 <liujiangbo1@xiaomi.com>
2024-11-05 20:27:57 +08:00
Lawrence Li
007b561e32
feat: add gpustack model provider (#10158) 2024-11-05 20:27:57 +08:00
takatost
c100f24f7d compatible model daemon request exception 2024-11-01 19:20:26 -07:00
takatost
d92cb994a9 fix voice list 2024-11-01 01:56:15 -07:00
Yeuoly
413326905e
rebase migrations 2024-11-01 16:55:07 +08:00
Yeuoly
5605ff9803
fix voice list 2024-11-01 16:42:32 +08:00
Yeuoly
84b7a4607a
fix: setup_required 2024-11-01 16:28:17 +08:00
Yeuoly
10cc4e758c
Merge branch 'main' into fix/chore-fix 2024-11-01 16:23:04 +08:00
Yeuoly
8070be9b76
fix: missing 'follow_redirects' argument while download plugin packages 2024-11-01 15:38:45 +08:00
Yeuoly
f1f1baae9c
feat: support plugin tags 2024-11-01 15:07:11 +08:00
takatost
f20c9ef763 fix 2024-11-01 00:01:05 -07:00
takatost
f798add31c compatible with original provider name 2024-11-01 00:00:53 -07:00
Yeuoly
8c2dbe876f
fix: custom tool parser 2024-11-01 14:26:56 +08:00
Yeuoly
6fd0a55b00
fix: correct dockerfile dependencies 2024-10-31 15:32:25 +08:00
Yeuoly
bb58f5c6e5
fix: avoid None to be assigned to WorkflowToolProviderController as provider id 2024-10-31 15:18:45 +08:00
takatost
18edeb8e0a integrate model provider with plugin daemon 2024-10-30 18:56:52 -07:00
Yeuoly
459cb9dd72
fix: transform plugin icon incorrect 2024-10-30 16:09:17 +08:00
Yeuoly
f9e2c738b0
fix: permission change api should not wraps a permission decorator 2024-10-29 17:16:32 +08:00
Yeuoly
739e15f88b
feat: support tool plugin id 2024-10-29 12:32:11 +08:00
Yeuoly
5bf86ff66d
feat: support latest package identifier 2024-10-28 15:56:15 +08:00
Yeuoly
c657378d06
feat: support plugin permission management 2024-10-28 15:54:34 +08:00
Yeuoly
685e8cdc7d
refactor: document segment query 2024-10-28 15:07:33 +08:00
Yeuoly
d36dece0af
feat: support upgrade interfaces 2024-10-25 18:56:38 +08:00
Yeuoly
5f61aa85db
feat: add latest version 2024-10-25 13:52:33 +08:00
Yeuoly
e5837b88e0
fix: add subpath 2024-10-25 13:26:32 +08:00
Yeuoly
ffdc6f5c60
feat: support remove single item from installation task 2024-10-25 13:22:37 +08:00
Yeuoly
99c8f364ae
fix: temp fix for empty redis password 2024-10-24 13:20:26 +08:00
Yeuoly
a0a1243c90
cleanup: remove hacked code 2024-10-22 17:56:13 +08:00
Yeuoly
b916b4064a
Merge remote-tracking branch 'origin/fix/tool-use-file' into fix/chore-fix 2024-10-22 17:47:01 +08:00
Yeuoly
dea2962a79
Merge main into feat/plugin 2024-10-22 17:35:11 +08:00
Yeuoly
1450e5d5cb
feat: add supports for multimodal 2024-10-22 17:26:00 +08:00
Joel
43a2d4335b fix: tool use file caused error 2024-10-22 16:51:11 +08:00
Yeuoly
11270a7ef2
Migrate to DeclarativeBaseModel 2024-10-21 20:38:27 +08:00
Yeuoly
53e1b45d40
fix: remove .query reference of db.Model 2024-10-21 20:23:27 +08:00
Yeuoly
bedbd658fe
Merge main into fix/chore-fix 2024-10-21 20:01:49 +08:00
Yeuoly
7b62b5578e
refactor: add manifest into upload interfaces 2024-10-21 18:48:03 +08:00
Yeuoly
ccbe42eb5f
feat: add plugin id into tool api entities 2024-10-17 20:46:29 +08:00
Yeuoly
45f8651a3d
feat: support backwards invoke summary 2024-10-17 19:44:30 +08:00
Yeuoly
7754431a34
feat: support plugin max package size 2024-10-17 18:44:16 +08:00
Yeuoly
fa7215cfea
Merge branch 'main' into fix/chore-fix 2024-10-17 13:46:43 +08:00
Yeuoly
678c89891a
feat: support verified 2024-10-17 13:40:33 +08:00
Yeuoly
beebcbd962
feat: add description 2024-10-17 12:59:11 +08:00
Yeuoly
8495ed3348
add conversation id, app id and message id into plugin session 2024-10-16 15:10:50 +08:00
Yeuoly
31cca4a849
fix: add marketplace switch 2024-10-16 14:47:48 +08:00
Yeuoly
43ffccc8fd
fix: install plugins 2024-10-16 14:02:05 +08:00
Yeuoly
a81293cf5a
feat: add category for plugins 2024-10-16 13:03:50 +08:00
Yeuoly
276701e1b7
refactor: plugin installation 2024-10-14 17:52:29 +08:00
Yeuoly
8e1cf3233c
fix: missing openai moderation 2024-10-14 16:42:36 +08:00
Yeuoly
dd551e6ca8
Ruff: reformatter 2024-10-14 16:25:51 +08:00
Yeuoly
ae1eeb9b2a
Mergin main into fix/chore-fix 2024-10-14 16:22:12 +08:00
Junyan Qin
b58f8dd7b4
feat: download pkg from marketplace (#9184) 2024-10-11 02:00:02 +08:00
Yeuoly
118fa66567
feat: backwards invoke tools 2024-10-10 18:09:06 +08:00
Yeuoly
699d41deec
fix: add source to plugin entity 2024-10-10 16:47:25 +08:00
Yeuoly
dd0462c1dc
feat: support two install source 2024-10-10 16:35:36 +08:00
Yeuoly
a470e0e60e
fix: missing detailed paths of endpoints 2024-10-10 00:12:46 +08:00
Yeuoly
2622159763
feat: support verify signature 2024-10-09 23:13:01 +08:00
Yeuoly
dfaf639790
feat: support endpoint url template 2024-10-09 22:58:36 +08:00
Yeuoly
ae96f66a08
feat: support list endpoints for single plugin, fix: failed to clear endpoint credentials 2024-10-09 22:33:18 +08:00
Yeuoly
570b7d18ac
fix: endpoint apis 2024-10-08 23:48:38 +08:00
Yeuoly
a9c21ef929
feat: uninstall plugins 2024-10-08 22:38:33 +08:00
Yeuoly
e27a03ae15
feat: support install plugin 2024-10-08 21:28:59 +08:00
Yeuoly
56b7853afe
feat: compat tool provider credentials to updated data 2024-09-30 23:22:23 +08:00
takatost
e12f4009d3 feat: optimize icon url 2024-09-30 17:46:40 +08:00
Yeuoly
6dfc31a542
refactor: credentials schemas to array 2024-09-30 17:39:13 +08:00
Yeuoly
c9f80b46a1
fix: add endpoint name 2024-09-30 16:57:09 +08:00
Yeuoly
0025b27200
fix: tool invocation logs 2024-09-29 21:09:01 +08:00
Yeuoly
0dd05d7b6d
feat: tool output schema 2024-09-29 20:58:07 +08:00
takatost
7c83d5ce76 feat: add dockerignore items 2024-09-29 20:16:21 +08:00
takatost
a57f60a6e0 feat: remove unused codes 2024-09-29 19:47:47 +08:00
Yeuoly
2f36692bf9
fix: get tool runtime parameters 2024-09-29 19:37:03 +08:00
takatost
bcdb407be8 feat: remove unused codes 2024-09-29 18:24:33 +08:00
Yeuoly
d4e007f9db
feat: support get tool runtime parameters 2024-09-29 18:19:03 +08:00
takatost
8563155d1b feat: remove unused codes 2024-09-29 18:18:01 +08:00
takatost
8236373498 feat: remove unused codes 2024-09-29 18:16:21 +08:00
Yeuoly
196bfeaaf4
Merge branch 'main' into fix/chore-fix 2024-09-29 17:14:10 +08:00
Yeuoly
957ab093c9
enhancement: reduce requests to plugin daemon 2024-09-29 17:07:40 +08:00
Yeuoly
e9e5c8806a
refactor: using DeclarativeBase as parent class of models, refactored tools 2024-09-29 17:00:58 +08:00
Yeuoly
c8bc3892b3
refactor: invoke tool from dify 2024-09-29 14:44:22 +08:00
Yeuoly
735e57b73a
fix: transform generic error message into correct type 2024-09-29 13:46:16 +08:00
Yeuoly
635a53ea38
fix: import undefined types 2024-09-29 13:23:14 +08:00
Yeuoly
7b76b1ff82
Merge fix/chore-fix into fix/chore-fix 2024-09-29 13:12:22 +08:00
takatost
47c8824be6 feat: move model request to plugin daemon 2024-09-29 00:15:17 +08:00
takatost
1c3213184e feat: move model request to plugin daemon 2024-09-29 00:15:14 +08:00
Yeuoly
d9cced8419
Merge branch 'main' into fix/chore-fix 2024-09-28 20:18:28 +08:00
Yeuoly
c3359a9291
refactor: using plugin id to dispatch request instead 2024-09-27 21:48:48 +08:00
Yeuoly
2da32e49d0
fix: tests 2024-09-26 17:51:13 +08:00
Yeuoly
1837692a66
fix: sse error message 2024-09-26 17:40:27 +08:00
Yeuoly
5dcd25a613
fix: missing error message 2024-09-26 17:22:39 +08:00
Yeuoly
507fff0259
fix: tts file was deleted before invocation 2024-09-26 15:47:16 +08:00
Yeuoly
0ad9dbea63
feat: backwards invoke model 2024-09-26 15:38:22 +08:00
Yeuoly
4c28034224
refactor: encryption 2024-09-26 14:51:10 +08:00
Yeuoly
1d575524c3
fix: missing user id 2024-09-26 14:20:05 +08:00
Yeuoly
dc255cc154
Merge main into feat/plugin 2024-09-26 12:59:06 +08:00
Yeuoly
ea497f828f
feat: endpoint management 2024-09-26 12:49:00 +08:00
Yeuoly
153dc5b3f3
feat: endpoint apis 2024-09-26 10:26:45 +08:00
Yeuoly
a91951b374
feat: invoke node 2024-09-24 20:15:13 +08:00
Yeuoly
68c10a1672
feat: add backwards invoke node api 2024-09-24 18:03:48 +08:00
Yeuoly
592f85f7a9
formatter 2024-09-24 16:40:42 +08:00
Yeuoly
cda9f6ec6b
Merge main into fix/chore-fix 2024-09-24 16:38:38 +08:00
Yeuoly
64706c709c
fix 2024-09-24 16:35:01 +08:00
Yeuoly
9722e6bcb1
fix: allow duplicate tool providers 2024-09-24 16:33:19 +08:00
Yeuoly
1907d791e1
enhance: add gzip 2024-09-24 16:15:50 +08:00
Yeuoly
fb3a701c86
fix: stream with empty line 2024-09-24 16:02:01 +08:00
Yeuoly
947bfdc807
feat: validate credentials 2024-09-23 21:13:02 +08:00
Yeuoly
7a3e756020
refactor: list tools 2024-09-23 18:06:16 +08:00
Yeuoly
435e71eb60
refactor 2024-09-23 13:09:46 +08:00
Yeuoly
91cb80f795
refactor: tool 2024-09-20 23:48:48 +08:00
Yeuoly
3c1d32e3ac
feat: uninstall plugin 2024-09-20 21:50:44 +08:00
Yeuoly
eef79a5196
feat: support install plugin 2024-09-20 21:35:19 +08:00
Yeuoly
2223dfb266
feat: get debugging key 2024-09-20 15:08:39 +08:00
Yeuoly
9693b5ad0c
feat: debugging key 2024-09-20 14:43:01 +08:00
Yeuoly
d4bf575d0a
impl: basic plugin manager 2024-09-20 13:55:09 +08:00
Yeuoly
73ce692e24
feat: add inner api key 2024-09-20 13:32:11 +08:00
Yeuoly
661392eaef
refactor: tool 2024-09-20 02:25:14 +08:00
Yeuoly
c472ea6c67
fix: pydantic 2024-09-19 18:02:24 +08:00
Yeuoly
4eaba3049a
Merge main 2024-09-19 17:54:08 +08:00
Yeuoly
00d1c45518
Merge main 2024-09-14 02:47:01 +08:00
Yeuoly
87c746f6bb
tmp 2024-09-14 01:26:22 +08:00
Yeuoly
70c001436e
support variable 2024-09-10 18:13:33 +08:00
Yeuoly
cf73374c1b
refactor: stream output 2024-09-10 17:16:55 +08:00
Yeuoly
b0d53c0ac4
Merge main 2024-09-10 15:42:59 +08:00
Yeuoly
9c7bcd5abc
Merge main 2024-09-10 14:05:20 +08:00
Yeuoly
b7c5abc5dd
reformatter 2024-08-30 23:29:04 +08:00
Yeuoly
de01ca8d55
feat: inner api encrypt 2024-08-30 21:25:58 +08:00
Yeuoly
60e75dc748
fix: linter 2024-08-30 21:11:39 +08:00
Yeuoly
279dee485d
feat: type 2024-08-30 21:10:19 +08:00
Yeuoly
db8bf2a85e
Merge branch 'main' into feat/plugin 2024-08-30 18:21:22 +08:00
Yeuoly
46ba16fe90
fix: reformatter 2024-08-30 18:21:03 +08:00
Yeuoly
886a160115
fix: invoke tool streamingly 2024-08-30 18:11:38 +08:00
Yeuoly
cf4e9f317e
refactor: tool models 2024-08-30 15:55:10 +08:00
Yeuoly
1fa3b9cfd8
refactor tools 2024-08-30 14:23:14 +08:00
Yeuoly
50a5cfe56a
fix: endpoint using default user 2024-08-29 21:48:20 +08:00
Yeuoly
ece82b87bf
feat: invoke app 2024-08-29 21:14:23 +08:00
Yeuoly
12ea085e22
feat: implement invoke app args 2024-08-29 20:50:36 +08:00
Yeuoly
41ed2e0cc2
feat: backwards invoke app 2024-08-29 20:17:17 +08:00
Yeuoly
113ff27d07
fix: types 2024-08-29 20:06:14 +08:00
Yeuoly
ec711d094d
refactor: enforce return object in app generator 2024-08-29 19:49:57 +08:00
Yeuoly
a073de44e9
Merge branch 'main' into feat/plugin 2024-08-29 17:08:44 +08:00
Yeuoly
6ce02b07d3
feat: add type annatation 2024-08-29 14:23:19 +08:00
Yeuoly
f47712beae
feat: add type annatation 2024-08-29 14:18:00 +08:00
Yeuoly
4a8d3c54ca
fix: workflow as tool type 2024-08-29 14:09:47 +08:00
Yeuoly
c8b0160ea9
fix: tool type 2024-08-29 14:06:10 +08:00
Yeuoly
531ffaec4f
fix: tool node 2024-08-29 13:56:48 +08:00
Yeuoly
c28998a6f0
refactor: tool message transformer 2024-08-29 13:42:31 +08:00
Yeuoly
4b4741f7ed
Merge main into feat/plugin 2024-08-29 13:09:13 +08:00
Yeuoly
25b8a512bf
feat: invoke app 2024-08-29 12:55:00 +08:00
Yeuoly
02d26818ad
Merge branch 'main' into feat/plugin 2024-07-31 14:51:36 +08:00
Yeuoly
31e8b134d1
feat: backwards invoke llm 2024-07-29 22:08:14 +08:00
Yeuoly
d52476c1c9
feat: support backwards invocation 2024-07-29 18:57:34 +08:00
Yeuoly
f29b44acd8
feat: support plugin inner api 2024-07-29 16:40:04 +08:00
Yeuoly
ed7fcc5f7d
Merge branch 'main' into feat/plugin 2024-07-29 16:07:19 +08:00
Yeuoly
c6f34f5c17
Merge branch 'main' into feat/plugin 2024-07-15 16:03:11 +08:00
Yeuoly
e1db77eec2
fix 2024-07-15 16:00:11 +08:00
Yeuoly
563d81277b
refactor: tool response to generator 2024-07-09 15:37:56 +08:00
Yeuoly
364df36ac4
feat: plugin call dify 2024-07-08 22:37:20 +08:00
3140 changed files with 61372 additions and 275190 deletions

View File

@ -1,11 +1,12 @@
#!/bin/bash #!/bin/bash
cd web && npm install npm add -g pnpm@9.12.2
cd web && pnpm install
pipx install poetry pipx install poetry
echo 'alias start-api="cd /workspaces/dify/api && poetry run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc echo 'alias start-api="cd /workspaces/dify/api && poetry run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
echo 'alias start-worker="cd /workspaces/dify/api && poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc echo 'alias start-worker="cd /workspaces/dify/api && poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
echo 'alias start-web="cd /workspaces/dify/web && npm run dev"' >> ~/.bashrc echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc
echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify down"' >> ~/.bashrc echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify down"' >> ~/.bashrc

View File

@ -4,6 +4,7 @@ on:
pull_request: pull_request:
branches: branches:
- main - main
- plugins/beta
paths: paths:
- api/** - api/**
- docker/** - docker/**
@ -47,15 +48,9 @@ jobs:
- name: Run Unit tests - name: Run Unit tests
run: poetry run -C api bash dev/pytest/pytest_unit_tests.sh run: poetry run -C api bash dev/pytest/pytest_unit_tests.sh
- name: Run ModelRuntime
run: poetry run -C api bash dev/pytest/pytest_model_runtime.sh
- name: Run dify config tests - name: Run dify config tests
run: poetry run -C api python dev/pytest/pytest_config_tests.py run: poetry run -C api python dev/pytest/pytest_config_tests.py
- name: Run Tool
run: poetry run -C api bash dev/pytest/pytest_tools.sh
- name: Run mypy - name: Run mypy
run: | run: |
pushd api pushd api

View File

@ -5,6 +5,7 @@ on:
branches: branches:
- "main" - "main"
- "deploy/dev" - "deploy/dev"
- "plugins/beta"
release: release:
types: [published] types: [published]

View File

@ -4,6 +4,7 @@ on:
pull_request: pull_request:
branches: branches:
- main - main
- plugins/beta
paths: paths:
- api/migrations/** - api/migrations/**
- .github/workflows/db-migration-test.yml - .github/workflows/db-migration-test.yml

View File

@ -4,6 +4,7 @@ on:
pull_request: pull_request:
branches: branches:
- main - main
- plugins/beta
concurrency: concurrency:
group: style-${{ github.head_ref || github.run_id }} group: style-${{ github.head_ref || github.run_id }}
@ -71,17 +72,16 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
with: with:
node-version: 20 node-version: 20
cache: yarn cache: pnpm
cache-dependency-path: ./web/package.json cache-dependency-path: ./web/package.json
- name: Web dependencies - name: Web dependencies
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
run: yarn install --frozen-lockfile run: pnpm install --frozen-lockfile
- name: Web style check - name: Web style check
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
run: yarn run lint run: pnpm run lint
superlinter: superlinter:
name: SuperLinter name: SuperLinter
@ -107,7 +107,7 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
env: env:
BASH_SEVERITY: warning BASH_SEVERITY: warning
DEFAULT_BRANCH: main DEFAULT_BRANCH: plugins/beta
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
IGNORE_GENERATED_FILES: true IGNORE_GENERATED_FILES: true
IGNORE_GITIGNORED_FILES: true IGNORE_GITIGNORED_FILES: true

View File

@ -32,10 +32,10 @@ jobs:
with: with:
node-version: ${{ matrix.node-version }} node-version: ${{ matrix.node-version }}
cache: '' cache: ''
cache-dependency-path: 'yarn.lock' cache-dependency-path: 'pnpm-lock.yaml'
- name: Install Dependencies - name: Install Dependencies
run: yarn install run: pnpm install
- name: Test - name: Test
run: yarn test run: pnpm test

View File

@ -38,11 +38,11 @@ jobs:
- name: Install dependencies - name: Install dependencies
if: env.FILES_CHANGED == 'true' if: env.FILES_CHANGED == 'true'
run: yarn install --frozen-lockfile run: pnpm install --frozen-lockfile
- name: Run npm script - name: Run npm script
if: env.FILES_CHANGED == 'true' if: env.FILES_CHANGED == 'true'
run: npm run auto-gen-i18n run: pnpm run auto-gen-i18n
- name: Create Pull Request - name: Create Pull Request
if: env.FILES_CHANGED == 'true' if: env.FILES_CHANGED == 'true'

View File

@ -34,13 +34,13 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
with: with:
node-version: 20 node-version: 20
cache: yarn cache: pnpm
cache-dependency-path: ./web/package.json cache-dependency-path: ./web/package.json
- name: Install dependencies - name: Install dependencies
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
run: yarn install --frozen-lockfile run: pnpm install --frozen-lockfile
- name: Run tests - name: Run tests
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
run: yarn test run: pnpm test

4
.gitignore vendored
View File

@ -175,6 +175,7 @@ docker/volumes/pgvector/data/*
docker/volumes/pgvecto_rs/data/* docker/volumes/pgvecto_rs/data/*
docker/volumes/couchbase/* docker/volumes/couchbase/*
docker/volumes/oceanbase/* docker/volumes/oceanbase/*
docker/volumes/plugin_daemon/*
!docker/volumes/oceanbase/init.d !docker/volumes/oceanbase/init.d
docker/nginx/conf.d/default.conf docker/nginx/conf.d/default.conf
@ -193,3 +194,6 @@ api/.vscode
.idea/ .idea/
.vscode .vscode
# pnpm
/.pnpm-store

View File

@ -1,7 +1,10 @@
.env .env
*.env.* *.env.*
storage/generate_files/*
storage/privkeys/* storage/privkeys/*
storage/tools/*
storage/upload_files/*
# Logs # Logs
logs logs
@ -9,6 +12,8 @@ logs
# jetbrains # jetbrains
.idea .idea
.mypy_cache
.ruff_cache
# venv # venv
.venv .venv

View File

@ -409,7 +409,6 @@ MAX_VARIABLE_SIZE=204800
APP_MAX_EXECUTION_TIME=1200 APP_MAX_EXECUTION_TIME=1200
APP_MAX_ACTIVE_REQUESTS=0 APP_MAX_ACTIVE_REQUESTS=0
# Celery beat configuration # Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1 CELERY_BEAT_SCHEDULER_TIME=1
@ -422,6 +421,22 @@ POSITION_PROVIDER_PINS=
POSITION_PROVIDER_INCLUDES= POSITION_PROVIDER_INCLUDES=
POSITION_PROVIDER_EXCLUDES= POSITION_PROVIDER_EXCLUDES=
# Plugin configuration
PLUGIN_API_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
PLUGIN_API_URL=http://127.0.0.1:5002
PLUGIN_REMOTE_INSTALL_PORT=5003
PLUGIN_REMOTE_INSTALL_HOST=localhost
PLUGIN_MAX_PACKAGE_SIZE=15728640
INNER_API_KEY=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
# Marketplace configuration
MARKETPLACE_ENABLED=true
MARKETPLACE_API_URL=https://marketplace.dify.ai
# Endpoint configuration
ENDPOINT_URL_TEMPLATE=http://localhost/e/{hook_id}
# Reset password token expiry minutes # Reset password token expiry minutes
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5

View File

@ -69,6 +69,10 @@ ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Download nltk data # Download nltk data
RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')" RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')"
ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache
RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')"
# Copy source code # Copy source code
COPY . /app/api/ COPY . /app/api/

View File

@ -25,6 +25,8 @@ from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
from models.provider import Provider, ProviderModel from models.provider import Provider, ProviderModel
from services.account_service import RegisterService, TenantService from services.account_service import RegisterService, TenantService
from services.plugin.data_migration import PluginDataMigration
from services.plugin.plugin_migration import PluginMigration
@click.command("reset-password", help="Reset the account password.") @click.command("reset-password", help="Reset the account password.")
@ -524,7 +526,7 @@ def add_qdrant_doc_id_index(field: str):
) )
) )
except Exception as e: except Exception:
click.echo(click.style("Failed to create Qdrant client.", fg="red")) click.echo(click.style("Failed to create Qdrant client.", fg="red"))
click.echo(click.style(f"Index creation complete. Created {create_count} collection indexes.", fg="green")) click.echo(click.style(f"Index creation complete. Created {create_count} collection indexes.", fg="green"))
@ -593,7 +595,7 @@ def upgrade_db():
click.echo(click.style("Database migration successful!", fg="green")) click.echo(click.style("Database migration successful!", fg="green"))
except Exception as e: except Exception:
logging.exception("Failed to execute database migration") logging.exception("Failed to execute database migration")
finally: finally:
lock.release() lock.release()
@ -639,7 +641,7 @@ where sites.id is null limit 1000"""
account = accounts[0] account = accounts[0]
print("Fixing missing site for app {}".format(app.id)) print("Fixing missing site for app {}".format(app.id))
app_was_created.send(app, account=account) app_was_created.send(app, account=account)
except Exception as e: except Exception:
failed_app_ids.append(app_id) failed_app_ids.append(app_id)
click.echo(click.style("Failed to fix missing site for app {}".format(app_id), fg="red")) click.echo(click.style("Failed to fix missing site for app {}".format(app_id), fg="red"))
logging.exception(f"Failed to fix app related site missing issue, app_id: {app_id}") logging.exception(f"Failed to fix app related site missing issue, app_id: {app_id}")
@ -649,3 +651,68 @@ where sites.id is null limit 1000"""
break break
click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green")) click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green"))
@click.command("migrate-data-for-plugin", help="Migrate data for plugin.")
def migrate_data_for_plugin():
"""
Migrate data for plugin.
"""
click.echo(click.style("Starting migrate data for plugin.", fg="white"))
PluginDataMigration.migrate()
click.echo(click.style("Migrate data for plugin completed.", fg="green"))
@click.command("extract-plugins", help="Extract plugins.")
@click.option("--output_file", prompt=True, help="The file to store the extracted plugins.", default="plugins.jsonl")
@click.option("--workers", prompt=True, help="The number of workers to extract plugins.", default=10)
def extract_plugins(output_file: str, workers: int):
"""
Extract plugins.
"""
click.echo(click.style("Starting extract plugins.", fg="white"))
PluginMigration.extract_plugins(output_file, workers)
click.echo(click.style("Extract plugins completed.", fg="green"))
@click.command("extract-unique-identifiers", help="Extract unique identifiers.")
@click.option(
"--output_file",
prompt=True,
help="The file to store the extracted unique identifiers.",
default="unique_identifiers.json",
)
@click.option(
"--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl"
)
def extract_unique_plugins(output_file: str, input_file: str):
"""
Extract unique plugins.
"""
click.echo(click.style("Starting extract unique plugins.", fg="white"))
PluginMigration.extract_unique_plugins_to_file(input_file, output_file)
click.echo(click.style("Extract unique plugins completed.", fg="green"))
@click.command("install-plugins", help="Install plugins.")
@click.option(
"--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl"
)
@click.option(
"--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl"
)
def install_plugins(input_file: str, output_file: str):
"""
Install plugins.
"""
click.echo(click.style("Starting install plugins.", fg="white"))
PluginMigration.install_plugins(input_file, output_file)
click.echo(click.style("Install plugins completed.", fg="green"))

View File

@ -134,6 +134,60 @@ class CodeExecutionSandboxConfig(BaseSettings):
) )
class PluginConfig(BaseSettings):
"""
Plugin configs
"""
PLUGIN_DAEMON_URL: HttpUrl = Field(
description="Plugin API URL",
default="http://plugin:5002",
)
PLUGIN_API_KEY: str = Field(
description="Plugin API key",
default="plugin-api-key",
)
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
PLUGIN_REMOTE_INSTALL_HOST: str = Field(
description="Plugin Remote Install Host",
default="localhost",
)
PLUGIN_REMOTE_INSTALL_PORT: PositiveInt = Field(
description="Plugin Remote Install Port",
default=5003,
)
PLUGIN_MAX_PACKAGE_SIZE: PositiveInt = Field(
description="Maximum allowed size for plugin packages in bytes",
default=15728640,
)
PLUGIN_MAX_BUNDLE_SIZE: PositiveInt = Field(
description="Maximum allowed size for plugin bundles in bytes",
default=15728640 * 12,
)
class MarketplaceConfig(BaseSettings):
"""
Configuration for marketplace
"""
MARKETPLACE_ENABLED: bool = Field(
description="Enable or disable marketplace",
default=True,
)
MARKETPLACE_API_URL: HttpUrl = Field(
description="Marketplace API URL",
default="https://marketplace.dify.ai",
)
class EndpointConfig(BaseSettings): class EndpointConfig(BaseSettings):
""" """
Configuration for various application endpoints and URLs Configuration for various application endpoints and URLs
@ -160,6 +214,10 @@ class EndpointConfig(BaseSettings):
default="", default="",
) )
ENDPOINT_URL_TEMPLATE: str = Field(
description="Template url for endpoint plugin", default="http://localhost:5002/e/{hook_id}"
)
class FileAccessConfig(BaseSettings): class FileAccessConfig(BaseSettings):
""" """
@ -788,6 +846,8 @@ class FeatureConfig(
AuthConfig, # Changed from OAuthConfig to AuthConfig AuthConfig, # Changed from OAuthConfig to AuthConfig
BillingConfig, BillingConfig,
CodeExecutionSandboxConfig, CodeExecutionSandboxConfig,
PluginConfig,
MarketplaceConfig,
DataSetConfig, DataSetConfig,
EndpointConfig, EndpointConfig,
FileAccessConfig, FileAccessConfig,

View File

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field( CURRENT_VERSION: str = Field(
description="Dify version", description="Dify version",
default="0.15.0", default="1.0.0-beta.1",
) )
COMMIT_SHA: str = Field( COMMIT_SHA: str = Field(

View File

@ -1,9 +1,19 @@
from contextvars import ContextVar from contextvars import ContextVar
from threading import Lock
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
tenant_id: ContextVar[str] = ContextVar("tenant_id") tenant_id: ContextVar[str] = ContextVar("tenant_id")
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool") workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
plugin_tool_providers: ContextVar[dict[str, "PluginToolProviderController"]] = ContextVar("plugin_tool_providers")
plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock")
plugin_model_providers: ContextVar[list["PluginModelProviderEntity"] | None] = ContextVar("plugin_model_providers")
plugin_model_providers_lock: ContextVar[Lock] = ContextVar("plugin_model_providers_lock")

View File

@ -2,7 +2,7 @@ from flask import Blueprint
from libs.external_api import ExternalApi from libs.external_api import ExternalApi
from .app.app_import import AppImportApi, AppImportConfirmApi from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi
from .explore.audio import ChatAudioApi, ChatTextApi from .explore.audio import ChatAudioApi, ChatTextApi
from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
from .explore.conversation import ( from .explore.conversation import (
@ -40,6 +40,7 @@ api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
# Import App # Import App
api.add_resource(AppImportApi, "/apps/imports") api.add_resource(AppImportApi, "/apps/imports")
api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm") api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm")
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
# Import other controllers # Import other controllers
from . import admin, apikey, extension, feature, ping, setup, version from . import admin, apikey, extension, feature, ping, setup, version
@ -166,4 +167,15 @@ api.add_resource(
from .tag import tags from .tag import tags
# Import workspace controllers # Import workspace controllers
from .workspace import account, load_balancing_config, members, model_providers, models, tool_providers, workspace from .workspace import (
account,
agent_providers,
endpoint,
load_balancing_config,
members,
model_providers,
models,
plugin,
tool_providers,
workspace,
)

View File

@ -2,6 +2,8 @@ from functools import wraps
from flask import request from flask import request
from flask_restful import Resource, reqparse # type: ignore from flask_restful import Resource, reqparse # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config from configs import dify_config
@ -54,7 +56,8 @@ class InsertExploreAppListApi(Resource):
parser.add_argument("position", type=int, required=True, nullable=False, location="json") parser.add_argument("position", type=int, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
app = App.query.filter(App.id == args["app_id"]).first() with Session(db.engine) as session:
app = session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none()
if not app: if not app:
raise NotFound(f'App \'{args["app_id"]}\' is not found') raise NotFound(f'App \'{args["app_id"]}\' is not found')
@ -70,7 +73,10 @@ class InsertExploreAppListApi(Resource):
privacy_policy = site.privacy_policy or args["privacy_policy"] or "" privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or "" custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first() with Session(db.engine) as session:
recommended_app = session.execute(
select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"])
).scalar_one_or_none()
if not recommended_app: if not recommended_app:
recommended_app = RecommendedApp( recommended_app = RecommendedApp(
@ -110,17 +116,27 @@ class InsertExploreAppApi(Resource):
@only_edition_cloud @only_edition_cloud
@admin_required @admin_required
def delete(self, app_id): def delete(self, app_id):
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first() with Session(db.engine) as session:
recommended_app = session.execute(
select(RecommendedApp).filter(RecommendedApp.app_id == str(app_id))
).scalar_one_or_none()
if not recommended_app: if not recommended_app:
return {"result": "success"}, 204 return {"result": "success"}, 204
app = App.query.filter(App.id == recommended_app.app_id).first() with Session(db.engine) as session:
app = session.execute(select(App).filter(App.id == recommended_app.app_id)).scalar_one_or_none()
if app: if app:
app.is_public = False app.is_public = False
installed_apps = InstalledApp.query.filter( with Session(db.engine) as session:
InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id installed_apps = session.execute(
).all() select(InstalledApp).filter(
InstalledApp.app_id == recommended_app.app_id,
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
)
).all()
for installed_app in installed_apps: for installed_app in installed_apps:
db.session.delete(installed_app) db.session.delete(installed_app)

View File

@ -3,6 +3,8 @@ from typing import Any
import flask_restful # type: ignore import flask_restful # type: ignore
from flask_login import current_user # type: ignore from flask_login import current_user # type: ignore
from flask_restful import Resource, fields, marshal_with from flask_restful import Resource, fields, marshal_with
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from extensions.ext_database import db from extensions.ext_database import db
@ -26,7 +28,16 @@ api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="it
def _get_resource(resource_id, tenant_id, resource_model): def _get_resource(resource_id, tenant_id, resource_model):
resource = resource_model.query.filter_by(id=resource_id, tenant_id=tenant_id).first() if resource_model == App:
with Session(db.engine) as session:
resource = session.execute(
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
).scalar_one_or_none()
else:
with Session(db.engine) as session:
resource = session.execute(
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
).scalar_one_or_none()
if resource is None: if resource is None:
flask_restful.abort(404, message=f"{resource_model.__name__} not found.") flask_restful.abort(404, message=f"{resource_model.__name__} not found.")

View File

@ -5,14 +5,16 @@ from flask_restful import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
setup_required, setup_required,
) )
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import app_import_fields from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
from libs.login import login_required from libs.login import login_required
from models import Account from models import Account
from models.model import App
from services.app_dsl_service import AppDslService, ImportStatus from services.app_dsl_service import AppDslService, ImportStatus
@ -88,3 +90,20 @@ class AppImportConfirmApi(Resource):
if result.status == ImportStatus.FAILED.value: if result.status == ImportStatus.FAILED.value:
return result.model_dump(mode="json"), 400 return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200 return result.model_dump(mode="json"), 200
class AppImportCheckDependenciesApi(Resource):
@setup_required
@login_required
@get_app_model
@account_initialization_required
@marshal_with(app_import_check_dependencies_fields)
def get(self, app_model: App):
if not current_user.is_editor:
raise Forbidden()
with Session(db.engine) as session:
import_service = AppDslService(session)
result = import_service.check_dependencies(app_model=app_model)
return result.model_dump(mode="json"), 200

View File

@ -2,6 +2,7 @@ from datetime import UTC, datetime
from flask_login import current_user # type: ignore from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal_with, reqparse # type: ignore from flask_restful import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
from constants.languages import supported_language from constants.languages import supported_language
@ -50,33 +51,37 @@ class AppSite(Resource):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
site = Site.query.filter(Site.app_id == app_model.id).one_or_404() with Session(db.engine) as session:
site = session.query(Site).filter(Site.app_id == app_model.id).first()
for attr_name in [ if not site:
"title", raise NotFound
"icon_type",
"icon",
"icon_background",
"description",
"default_language",
"chat_color_theme",
"chat_color_theme_inverted",
"customize_domain",
"copyright",
"privacy_policy",
"custom_disclaimer",
"customize_token_strategy",
"prompt_public",
"show_workflow_steps",
"use_icon_as_answer_icon",
]:
value = args.get(attr_name)
if value is not None:
setattr(site, attr_name, value)
site.updated_by = current_user.id for attr_name in [
site.updated_at = datetime.now(UTC).replace(tzinfo=None) "title",
db.session.commit() "icon_type",
"icon",
"icon_background",
"description",
"default_language",
"chat_color_theme",
"chat_color_theme_inverted",
"customize_domain",
"copyright",
"privacy_policy",
"custom_disclaimer",
"customize_token_strategy",
"prompt_public",
"show_workflow_steps",
"use_icon_as_answer_icon",
]:
value = args.get(attr_name)
if value is not None:
setattr(site, attr_name, value)
site.updated_by = current_user.id
site.updated_at = datetime.now(UTC).replace(tzinfo=None)
session.commit()
return site return site

View File

@ -20,6 +20,7 @@ from libs import helper
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required from libs.login import current_user, login_required
from models import App from models import App
from models.account import Account
from models.model import AppMode from models.model import AppMode
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError from services.errors.app import WorkflowHashNotEqualError
@ -96,6 +97,9 @@ class DraftWorkflowApi(Resource):
else: else:
abort(415) abort(415)
if not isinstance(current_user, Account):
raise Forbidden()
workflow_service = WorkflowService() workflow_service = WorkflowService()
try: try:
@ -139,6 +143,9 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json") parser.add_argument("inputs", type=dict, location="json")
parser.add_argument("query", type=str, required=True, location="json", default="") parser.add_argument("query", type=str, required=True, location="json", default="")
@ -160,7 +167,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
raise ConversationCompletedError() raise ConversationCompletedError()
except ValueError as e: except ValueError as e:
raise e raise e
except Exception as e: except Exception:
logging.exception("internal server error.") logging.exception("internal server error.")
raise InternalServerError() raise InternalServerError()
@ -178,38 +185,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() if not isinstance(current_user, Account):
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
try:
response = AppGenerateService.generate_single_iteration(
app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True
)
return helper.compact_generate_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
class WorkflowDraftRunIterationNodeApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
def post(self, app_model: App, node_id: str):
"""
Run draft workflow iteration node
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -228,7 +204,44 @@ class WorkflowDraftRunIterationNodeApi(Resource):
raise ConversationCompletedError() raise ConversationCompletedError()
except ValueError as e: except ValueError as e:
raise e raise e
except Exception as e: except Exception:
logging.exception("internal server error.")
raise InternalServerError()
class WorkflowDraftRunIterationNodeApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
def post(self, app_model: App, node_id: str):
"""
Run draft workflow iteration node
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
try:
response = AppGenerateService.generate_single_iteration(
app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True
)
return helper.compact_generate_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except ValueError as e:
raise e
except Exception:
logging.exception("internal server error.") logging.exception("internal server error.")
raise InternalServerError() raise InternalServerError()
@ -246,6 +259,9 @@ class DraftWorkflowRunApi(Resource):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json") parser.add_argument("files", type=list, required=False, location="json")
@ -294,13 +310,20 @@ class DraftWorkflowNodeRunApi(Resource):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
inputs = args.get("inputs")
if inputs == None:
raise ValueError("missing inputs")
workflow_service = WorkflowService() workflow_service = WorkflowService()
workflow_node_execution = workflow_service.run_draft_workflow_node( workflow_node_execution = workflow_service.run_draft_workflow_node(
app_model=app_model, node_id=node_id, user_inputs=args.get("inputs"), account=current_user app_model=app_model, node_id=node_id, user_inputs=inputs, account=current_user
) )
return workflow_node_execution return workflow_node_execution
@ -339,6 +362,9 @@ class PublishedWorkflowApi(Resource):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
workflow_service = WorkflowService() workflow_service = WorkflowService()
workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user) workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user)
@ -376,12 +402,17 @@ class DefaultBlockConfigApi(Resource):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("q", type=str, location="args") parser.add_argument("q", type=str, location="args")
args = parser.parse_args() args = parser.parse_args()
q = args.get("q")
filters = None filters = None
if args.get("q"): if q:
try: try:
filters = json.loads(args.get("q", "")) filters = json.loads(args.get("q", ""))
except json.JSONDecodeError: except json.JSONDecodeError:
@ -407,6 +438,9 @@ class ConvertToWorkflowApi(Resource):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
if request.data: if request.data:
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=False, nullable=True, location="json") parser.add_argument("name", type=str, required=False, nullable=True, location="json")

View File

@ -3,6 +3,8 @@ import secrets
from flask import request from flask import request
from flask_restful import Resource, reqparse # type: ignore from flask_restful import Resource, reqparse # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session
from constants.languages import languages from constants.languages import languages
from controllers.console import api from controllers.console import api
@ -37,7 +39,8 @@ class ForgotPasswordSendEmailApi(Resource):
else: else:
language = "en-US" language = "en-US"
account = Account.query.filter_by(email=args["email"]).first() with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
token = None token = None
if account is None: if account is None:
if FeatureService.get_system_features().is_allow_register: if FeatureService.get_system_features().is_allow_register:
@ -104,7 +107,8 @@ class ForgotPasswordResetApi(Resource):
password_hashed = hash_password(new_password, salt) password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode() base64_password_hashed = base64.b64encode(password_hashed).decode()
account = Account.query.filter_by(email=reset_data.get("email")).first() with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=reset_data.get("email"))).scalar_one_or_none()
if account: if account:
account.password = base64_password_hashed account.password = base64_password_hashed
account.password_salt = base64_salt account.password_salt = base64_salt
@ -125,7 +129,7 @@ class ForgotPasswordResetApi(Resource):
) )
except WorkSpaceNotAllowedCreateError: except WorkSpaceNotAllowedCreateError:
pass pass
except AccountRegisterError as are: except AccountRegisterError:
raise AccountInFreezeError() raise AccountInFreezeError()
return {"result": "success"} return {"result": "success"}

View File

@ -5,6 +5,8 @@ from typing import Optional
import requests import requests
from flask import current_app, redirect, request from flask import current_app, redirect, request
from flask_restful import Resource # type: ignore from flask_restful import Resource # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized from werkzeug.exceptions import Unauthorized
from configs import dify_config from configs import dify_config
@ -135,7 +137,8 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
account: Optional[Account] = Account.get_by_openid(provider, user_info.id) account: Optional[Account] = Account.get_by_openid(provider, user_info.id)
if not account: if not account:
account = Account.query.filter_by(email=user_info.email).first() with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none()
return account return account

View File

@ -4,6 +4,8 @@ import json
from flask import request from flask import request
from flask_login import current_user # type: ignore from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal_with, reqparse # type: ignore from flask_restful import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console import api from controllers.console import api
@ -76,7 +78,10 @@ class DataSourceApi(Resource):
def patch(self, binding_id, action): def patch(self, binding_id, action):
binding_id = str(binding_id) binding_id = str(binding_id)
action = str(action) action = str(action)
data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first() with Session(db.engine) as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).filter_by(id=binding_id)
).scalar_one_or_none()
if data_source_binding is None: if data_source_binding is None:
raise NotFound("Data source binding not found.") raise NotFound("Data source binding not found.")
# enable binding # enable binding
@ -108,47 +113,53 @@ class DataSourceNotionListApi(Resource):
def get(self): def get(self):
dataset_id = request.args.get("dataset_id", default=None, type=str) dataset_id = request.args.get("dataset_id", default=None, type=str)
exist_page_ids = [] exist_page_ids = []
# import notion in the exist dataset with Session(db.engine) as session:
if dataset_id: # import notion in the exist dataset
dataset = DatasetService.get_dataset(dataset_id) if dataset_id:
if not dataset: dataset = DatasetService.get_dataset(dataset_id)
raise NotFound("Dataset not found.") if not dataset:
if dataset.data_source_type != "notion_import": raise NotFound("Dataset not found.")
raise ValueError("Dataset is not notion type.") if dataset.data_source_type != "notion_import":
documents = Document.query.filter_by( raise ValueError("Dataset is not notion type.")
dataset_id=dataset_id,
tenant_id=current_user.current_tenant_id, documents = session.execute(
data_source_type="notion_import", select(Document).filter_by(
enabled=True, dataset_id=dataset_id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
enabled=True,
)
).all()
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info["notion_page_id"])
# get all authorized pages
data_source_bindings = session.execute(
select(DataSourceOauthBinding).filter_by(
tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
)
).all() ).all()
if documents: if not data_source_bindings:
for document in documents: return {"notion_info": []}, 200
data_source_info = json.loads(document.data_source_info) pre_import_info_list = []
exist_page_ids.append(data_source_info["notion_page_id"]) for data_source_binding in data_source_bindings:
# get all authorized pages source_info = data_source_binding.source_info
data_source_bindings = DataSourceOauthBinding.query.filter_by( pages = source_info["pages"]
tenant_id=current_user.current_tenant_id, provider="notion", disabled=False # Filter out already bound pages
).all() for page in pages:
if not data_source_bindings: if page["page_id"] in exist_page_ids:
return {"notion_info": []}, 200 page["is_bound"] = True
pre_import_info_list = [] else:
for data_source_binding in data_source_bindings: page["is_bound"] = False
source_info = data_source_binding.source_info pre_import_info = {
pages = source_info["pages"] "workspace_name": source_info["workspace_name"],
# Filter out already bound pages "workspace_icon": source_info["workspace_icon"],
for page in pages: "workspace_id": source_info["workspace_id"],
if page["page_id"] in exist_page_ids: "pages": pages,
page["is_bound"] = True }
else: pre_import_info_list.append(pre_import_info)
page["is_bound"] = False return {"notion_info": pre_import_info_list}, 200
pre_import_info = {
"workspace_name": source_info["workspace_name"],
"workspace_icon": source_info["workspace_icon"],
"workspace_id": source_info["workspace_id"],
"pages": pages,
}
pre_import_info_list.append(pre_import_info)
return {"notion_info": pre_import_info_list}, 200
class DataSourceNotionApi(Resource): class DataSourceNotionApi(Resource):
@ -158,14 +169,17 @@ class DataSourceNotionApi(Resource):
def get(self, workspace_id, page_id, page_type): def get(self, workspace_id, page_id, page_type):
workspace_id = str(workspace_id) workspace_id = str(workspace_id)
page_id = str(page_id) page_id = str(page_id)
data_source_binding = DataSourceOauthBinding.query.filter( with Session(db.engine) as session:
db.and_( data_source_binding = session.execute(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, select(DataSourceOauthBinding).filter(
DataSourceOauthBinding.provider == "notion", db.and_(
DataSourceOauthBinding.disabled == False, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', DataSourceOauthBinding.provider == "notion",
) DataSourceOauthBinding.disabled == False,
).first() DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
)
).scalar_one_or_none()
if not data_source_binding: if not data_source_binding:
raise NotFound("Data source binding not found.") raise NotFound("Data source binding not found.")

View File

@ -7,7 +7,6 @@ from flask import request
from flask_login import current_user # type: ignore from flask_login import current_user # type: ignore
from flask_restful import Resource, fields, marshal, marshal_with, reqparse # type: ignore from flask_restful import Resource, fields, marshal, marshal_with, reqparse # type: ignore
from sqlalchemy import asc, desc from sqlalchemy import asc, desc
from transformers.hf_argparser import string_to_bool # type: ignore
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
import services import services
@ -40,6 +39,7 @@ from core.indexing_runner import IndexingRunner
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.plugin.manager.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
@ -150,8 +150,20 @@ class DatasetDocumentListApi(Resource):
sort = request.args.get("sort", default="-created_at", type=str) sort = request.args.get("sort", default="-created_at", type=str)
# "yes", "true", "t", "y", "1" convert to True, while others convert to False. # "yes", "true", "t", "y", "1" convert to True, while others convert to False.
try: try:
fetch = string_to_bool(request.args.get("fetch", default="false")) fetch_val = request.args.get("fetch", default="false")
except (ArgumentTypeError, ValueError, Exception) as e: if isinstance(fetch_val, bool):
fetch = fetch_val
else:
if fetch_val.lower() in ("yes", "true", "t", "y", "1"):
fetch = True
elif fetch_val.lower() in ("no", "false", "f", "n", "0"):
fetch = False
else:
raise ArgumentTypeError(
f"Truthy value expected: got {fetch_val} but expected one of yes/no, true/false, t/f, y/n, 1/0 "
f"(case insensitive)."
)
except (ArgumentTypeError, ValueError, Exception):
fetch = False fetch = False
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset: if not dataset:
@ -430,6 +442,8 @@ class DocumentIndexingEstimateApi(DocumentResource):
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
except PluginDaemonClientSideError as ex:
raise ProviderNotInitializeError(ex.description)
except Exception as e: except Exception as e:
raise IndexingEstimateError(str(e)) raise IndexingEstimateError(str(e))
@ -531,6 +545,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
except PluginDaemonClientSideError as ex:
raise ProviderNotInitializeError(ex.description)
except Exception as e: except Exception as e:
raise IndexingEstimateError(str(e)) raise IndexingEstimateError(str(e))

View File

@ -2,8 +2,11 @@ import os
from flask import session from flask import session
from flask_restful import Resource, reqparse # type: ignore from flask_restful import Resource, reqparse # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
from extensions.ext_database import db
from libs.helper import StrLen from libs.helper import StrLen
from models.model import DifySetup from models.model import DifySetup
from services.account_service import TenantService from services.account_service import TenantService
@ -42,7 +45,11 @@ class InitValidateAPI(Resource):
def get_init_validate_status(): def get_init_validate_status():
if dify_config.EDITION == "SELF_HOSTED": if dify_config.EDITION == "SELF_HOSTED":
if os.environ.get("INIT_PASSWORD"): if os.environ.get("INIT_PASSWORD"):
return session.get("is_init_validated") or DifySetup.query.first() if session.get("is_init_validated"):
return True
with Session(db.engine) as db_session:
return db_session.execute(select(DifySetup)).scalar_one_or_none()
return True return True

View File

@ -4,7 +4,7 @@ from flask_restful import Resource, reqparse # type: ignore
from configs import dify_config from configs import dify_config
from libs.helper import StrLen, email, extract_remote_ip from libs.helper import StrLen, email, extract_remote_ip
from libs.password import valid_password from libs.password import valid_password
from models.model import DifySetup from models.model import DifySetup, db
from services.account_service import RegisterService, TenantService from services.account_service import RegisterService, TenantService
from . import api from . import api
@ -52,8 +52,9 @@ class SetupApi(Resource):
def get_setup_status(): def get_setup_status():
if dify_config.EDITION == "SELF_HOSTED": if dify_config.EDITION == "SELF_HOSTED":
return DifySetup.query.first() return db.session.query(DifySetup).first()
return True else:
return True
api.add_resource(SetupApi, "/setup") api.add_resource(SetupApi, "/setup")

View File

@ -0,0 +1,56 @@
from functools import wraps
from flask_login import current_user # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from extensions.ext_database import db
from models.account import TenantPluginPermission
def plugin_permission_required(
install_required: bool = False,
debug_required: bool = False,
):
def interceptor(view):
@wraps(view)
def decorated(*args, **kwargs):
user = current_user
tenant_id = user.current_tenant_id
with Session(db.engine) as session:
permission = (
session.query(TenantPluginPermission)
.filter(
TenantPluginPermission.tenant_id == tenant_id,
)
.first()
)
if not permission:
# no permission set, allow access for everyone
return view(*args, **kwargs)
if install_required:
if permission.install_permission == TenantPluginPermission.InstallPermission.NOBODY:
raise Forbidden()
if permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS:
if not user.is_admin_or_owner:
raise Forbidden()
if permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE:
pass
if debug_required:
if permission.debug_permission == TenantPluginPermission.DebugPermission.NOBODY:
raise Forbidden()
if permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS:
if not user.is_admin_or_owner:
raise Forbidden()
if permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE:
pass
return view(*args, **kwargs)
return decorated
return interceptor

View File

@ -0,0 +1,36 @@
from flask_login import current_user # type: ignore
from flask_restful import Resource # type: ignore
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from services.agent_service import AgentService
class AgentProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
class AgentProviderApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_name: str):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))
api.add_resource(AgentProviderListApi, "/workspaces/current/agent-providers")
api.add_resource(AgentProviderApi, "/workspaces/current/agent-provider/<path:provider_name>")

View File

@ -0,0 +1,205 @@
from flask_login import current_user # type: ignore
from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from services.plugin.endpoint_service import EndpointService
class EndpointCreateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True)
parser.add_argument("settings", type=dict, required=True)
parser.add_argument("name", type=str, required=True)
args = parser.parse_args()
plugin_unique_identifier = args["plugin_unique_identifier"]
settings = args["settings"]
name = args["name"]
return {
"success": EndpointService.create_endpoint(
tenant_id=user.current_tenant_id,
user_id=user.id,
plugin_unique_identifier=plugin_unique_identifier,
name=name,
settings=settings,
)
}
class EndpointListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user = current_user
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
parser.add_argument("page_size", type=int, required=True, location="args")
args = parser.parse_args()
page = args["page"]
page_size = args["page_size"]
return jsonable_encoder(
{
"endpoints": EndpointService.list_endpoints(
tenant_id=user.current_tenant_id,
user_id=user.id,
page=page,
page_size=page_size,
)
}
)
class EndpointListForSinglePluginApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user = current_user
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
parser.add_argument("page_size", type=int, required=True, location="args")
parser.add_argument("plugin_id", type=str, required=True, location="args")
args = parser.parse_args()
page = args["page"]
page_size = args["page_size"]
plugin_id = args["plugin_id"]
return jsonable_encoder(
{
"endpoints": EndpointService.list_endpoints_for_single_plugin(
tenant_id=user.current_tenant_id,
user_id=user.id,
plugin_id=plugin_id,
page=page,
page_size=page_size,
)
}
)
class EndpointDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
user = current_user
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args()
if not user.is_admin_or_owner:
raise Forbidden()
endpoint_id = args["endpoint_id"]
return {
"success": EndpointService.delete_endpoint(
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
)
}
class EndpointUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
user = current_user
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
parser.add_argument("settings", type=dict, required=True)
parser.add_argument("name", type=str, required=True)
args = parser.parse_args()
endpoint_id = args["endpoint_id"]
settings = args["settings"]
name = args["name"]
if not user.is_admin_or_owner:
raise Forbidden()
return {
"success": EndpointService.update_endpoint(
tenant_id=user.current_tenant_id,
user_id=user.id,
endpoint_id=endpoint_id,
name=name,
settings=settings,
)
}
class EndpointEnableApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
user = current_user
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args()
endpoint_id = args["endpoint_id"]
if not user.is_admin_or_owner:
raise Forbidden()
return {
"success": EndpointService.enable_endpoint(
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
)
}
class EndpointDisableApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
user = current_user
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args()
endpoint_id = args["endpoint_id"]
if not user.is_admin_or_owner:
raise Forbidden()
return {
"success": EndpointService.disable_endpoint(
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
)
}
api.add_resource(EndpointCreateApi, "/workspaces/current/endpoints/create")
api.add_resource(EndpointListApi, "/workspaces/current/endpoints/list")
api.add_resource(EndpointListForSinglePluginApi, "/workspaces/current/endpoints/list/plugin")
api.add_resource(EndpointDeleteApi, "/workspaces/current/endpoints/delete")
api.add_resource(EndpointUpdateApi, "/workspaces/current/endpoints/update")
api.add_resource(EndpointEnableApi, "/workspaces/current/endpoints/enable")
api.add_resource(EndpointDisableApi, "/workspaces/current/endpoints/disable")

View File

@ -112,10 +112,10 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
# Load Balancing Config # Load Balancing Config
api.add_resource( api.add_resource(
LoadBalancingCredentialsValidateApi, LoadBalancingCredentialsValidateApi,
"/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/credentials-validate", "/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate",
) )
api.add_resource( api.add_resource(
LoadBalancingConfigCredentialsValidateApi, LoadBalancingConfigCredentialsValidateApi,
"/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate", "/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate",
) )

View File

@ -79,7 +79,7 @@ class ModelProviderValidateApi(Resource):
response = {"result": "success" if result else "error"} response = {"result": "success" if result else "error"}
if not result: if not result:
response["error"] = error response["error"] = error or "Unknown error"
return response return response
@ -125,9 +125,10 @@ class ModelProviderIconApi(Resource):
Get model provider icon Get model provider icon
""" """
def get(self, provider: str, icon_type: str, lang: str): def get(self, tenant_id: str, provider: str, icon_type: str, lang: str):
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
icon, mimetype = model_provider_service.get_model_provider_icon( icon, mimetype = model_provider_service.get_model_provider_icon(
tenant_id=tenant_id,
provider=provider, provider=provider,
icon_type=icon_type, icon_type=icon_type,
lang=lang, lang=lang,
@ -183,53 +184,17 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
return data return data
class ModelProviderFreeQuotaSubmitApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
model_provider_service = ModelProviderService()
result = model_provider_service.free_quota_submit(tenant_id=current_user.current_tenant_id, provider=provider)
return result
class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=False, nullable=True, location="args")
args = parser.parse_args()
model_provider_service = ModelProviderService()
result = model_provider_service.free_quota_qualification_verify(
tenant_id=current_user.current_tenant_id, provider=provider, token=args["token"]
)
return result
api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers") api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers")
api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<string:provider>/credentials") api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<path:provider>/credentials")
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<string:provider>/credentials/validate") api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate")
api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<string:provider>") api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<path:provider>")
api.add_resource(
ModelProviderIconApi, "/workspaces/current/model-providers/<string:provider>/<string:icon_type>/<string:lang>"
)
api.add_resource( api.add_resource(
PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<string:provider>/preferred-provider-type" PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<path:provider>/preferred-provider-type"
) )
api.add_resource(ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<path:provider>/checkout-url")
api.add_resource( api.add_resource(
ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<string:provider>/checkout-url" ModelProviderIconApi,
) "/workspaces/<string:tenant_id>/model-providers/<path:provider>/<string:icon_type>/<string:lang>",
api.add_resource(
ModelProviderFreeQuotaSubmitApi, "/workspaces/current/model-providers/<string:provider>/free-quota-submit"
)
api.add_resource(
ModelProviderFreeQuotaQualificationVerifyApi,
"/workspaces/current/model-providers/<string:provider>/free-quota-qualification-verify",
) )

View File

@ -325,7 +325,7 @@ class ModelProviderModelValidateApi(Resource):
response = {"result": "success" if result else "error"} response = {"result": "success" if result else "error"}
if not result: if not result:
response["error"] = error response["error"] = error or ""
return response return response
@ -362,26 +362,26 @@ class ModelProviderAvailableModelApi(Resource):
return jsonable_encoder({"data": models}) return jsonable_encoder({"data": models})
api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<string:provider>/models") api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<path:provider>/models")
api.add_resource( api.add_resource(
ModelProviderModelEnableApi, ModelProviderModelEnableApi,
"/workspaces/current/model-providers/<string:provider>/models/enable", "/workspaces/current/model-providers/<path:provider>/models/enable",
endpoint="model-provider-model-enable", endpoint="model-provider-model-enable",
) )
api.add_resource( api.add_resource(
ModelProviderModelDisableApi, ModelProviderModelDisableApi,
"/workspaces/current/model-providers/<string:provider>/models/disable", "/workspaces/current/model-providers/<path:provider>/models/disable",
endpoint="model-provider-model-disable", endpoint="model-provider-model-disable",
) )
api.add_resource( api.add_resource(
ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<string:provider>/models/credentials" ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<path:provider>/models/credentials"
) )
api.add_resource( api.add_resource(
ModelProviderModelValidateApi, "/workspaces/current/model-providers/<string:provider>/models/credentials/validate" ModelProviderModelValidateApi, "/workspaces/current/model-providers/<path:provider>/models/credentials/validate"
) )
api.add_resource( api.add_resource(
ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<string:provider>/models/parameter-rules" ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<path:provider>/models/parameter-rules"
) )
api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>") api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")
api.add_resource(DefaultModelApi, "/workspaces/current/default-model") api.add_resource(DefaultModelApi, "/workspaces/current/default-model")

View File

@ -0,0 +1,475 @@
import io
from flask import request, send_file
from flask_login import current_user # type: ignore
from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.manager.exc import PluginDaemonClientSideError
from libs.login import login_required
from models.account import TenantPluginPermission
from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService
class PluginDebuggingKeyApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
try:
return {
"key": PluginService.get_debugging_key(tenant_id),
"host": dify_config.PLUGIN_REMOTE_INSTALL_HOST,
"port": dify_config.PLUGIN_REMOTE_INSTALL_PORT,
}
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
try:
plugins = PluginService.list(tenant_id)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"plugins": plugins})
class PluginListInstallationsFromIdsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("plugin_ids", type=list, required=True, location="json")
args = parser.parse_args()
try:
plugins = PluginService.list_installations_from_ids(tenant_id, args["plugin_ids"])
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"plugins": plugins})
class PluginIconApi(Resource):
@setup_required
def get(self):
req = reqparse.RequestParser()
req.add_argument("tenant_id", type=str, required=True, location="args")
req.add_argument("filename", type=str, required=True, location="args")
args = req.parse_args()
try:
icon_bytes, mimetype = PluginService.get_asset(args["tenant_id"], args["filename"])
except PluginDaemonClientSideError as e:
raise ValueError(e)
icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
class PluginUploadFromPkgApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
file = request.files["pkg"]
# check file size
if file.content_length > dify_config.PLUGIN_MAX_PACKAGE_SIZE:
raise ValueError("File size exceeds the maximum allowed size")
content = file.read()
try:
response = PluginService.upload_pkg(tenant_id, content)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder(response)
class PluginUploadFromGithubApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("repo", type=str, required=True, location="json")
parser.add_argument("version", type=str, required=True, location="json")
parser.add_argument("package", type=str, required=True, location="json")
args = parser.parse_args()
try:
response = PluginService.upload_pkg_from_github(tenant_id, args["repo"], args["version"], args["package"])
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder(response)
class PluginUploadFromBundleApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
file = request.files["bundle"]
# check file size
if file.content_length > dify_config.PLUGIN_MAX_BUNDLE_SIZE:
raise ValueError("File size exceeds the maximum allowed size")
content = file.read()
try:
response = PluginService.upload_bundle(tenant_id, content)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder(response)
class PluginInstallFromPkgApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
args = parser.parse_args()
# check if all plugin_unique_identifiers are valid string
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
if not isinstance(plugin_unique_identifier, str):
raise ValueError("Invalid plugin unique identifier")
try:
response = PluginService.install_from_local_pkg(tenant_id, args["plugin_unique_identifiers"])
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder(response)
class PluginInstallFromGithubApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("repo", type=str, required=True, location="json")
parser.add_argument("version", type=str, required=True, location="json")
parser.add_argument("package", type=str, required=True, location="json")
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
args = parser.parse_args()
try:
response = PluginService.install_from_github(
tenant_id,
args["plugin_unique_identifier"],
args["repo"],
args["version"],
args["package"],
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder(response)
class PluginInstallFromMarketplaceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
args = parser.parse_args()
# check if all plugin_unique_identifiers are valid string
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
if not isinstance(plugin_unique_identifier, str):
raise ValueError("Invalid plugin unique identifier")
try:
response = PluginService.install_from_marketplace_pkg(tenant_id, args["plugin_unique_identifiers"])
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder(response)
class PluginFetchManifestApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
args = parser.parse_args()
try:
return jsonable_encoder(
{
"manifest": PluginService.fetch_plugin_manifest(
tenant_id, args["plugin_unique_identifier"]
).model_dump()
}
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginFetchInstallTasksApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
parser.add_argument("page_size", type=int, required=True, location="args")
args = parser.parse_args()
try:
return jsonable_encoder(
{"tasks": PluginService.fetch_install_tasks(tenant_id, args["page"], args["page_size"])}
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginFetchInstallTaskApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self, task_id: str):
tenant_id = current_user.current_tenant_id
try:
return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)})
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginDeleteInstallTaskApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self, task_id: str):
tenant_id = current_user.current_tenant_id
try:
return {"success": PluginService.delete_install_task(tenant_id, task_id)}
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginDeleteAllInstallTaskItemsApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
try:
return {"success": PluginService.delete_all_install_task_items(tenant_id)}
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginDeleteInstallTaskItemApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self, task_id: str, identifier: str):
tenant_id = current_user.current_tenant_id
try:
return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)}
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginUpgradeFromMarketplaceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
args = parser.parse_args()
try:
return jsonable_encoder(
PluginService.upgrade_plugin_with_marketplace(
tenant_id, args["original_plugin_unique_identifier"], args["new_plugin_unique_identifier"]
)
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginUpgradeFromGithubApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
parser.add_argument("repo", type=str, required=True, location="json")
parser.add_argument("version", type=str, required=True, location="json")
parser.add_argument("package", type=str, required=True, location="json")
args = parser.parse_args()
try:
return jsonable_encoder(
PluginService.upgrade_plugin_with_github(
tenant_id,
args["original_plugin_unique_identifier"],
args["new_plugin_unique_identifier"],
args["repo"],
args["version"],
args["package"],
)
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginUninstallApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self):
req = reqparse.RequestParser()
req.add_argument("plugin_installation_id", type=str, required=True, location="json")
args = req.parse_args()
tenant_id = current_user.current_tenant_id
try:
return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginChangePermissionApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
req = reqparse.RequestParser()
req.add_argument("install_permission", type=str, required=True, location="json")
req.add_argument("debug_permission", type=str, required=True, location="json")
args = req.parse_args()
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
tenant_id = user.current_tenant_id
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
class PluginFetchPermissionApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
permission = PluginPermissionService.get_permission(tenant_id)
if not permission:
return jsonable_encoder(
{
"install_permission": TenantPluginPermission.InstallPermission.EVERYONE,
"debug_permission": TenantPluginPermission.DebugPermission.EVERYONE,
}
)
return jsonable_encoder(
{
"install_permission": permission.install_permission,
"debug_permission": permission.debug_permission,
}
)
api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
api.add_resource(PluginListApi, "/workspaces/current/plugin/list")
api.add_resource(PluginListInstallationsFromIdsApi, "/workspaces/current/plugin/list/installations/ids")
api.add_resource(PluginIconApi, "/workspaces/current/plugin/icon")
api.add_resource(PluginUploadFromPkgApi, "/workspaces/current/plugin/upload/pkg")
api.add_resource(PluginUploadFromGithubApi, "/workspaces/current/plugin/upload/github")
api.add_resource(PluginUploadFromBundleApi, "/workspaces/current/plugin/upload/bundle")
api.add_resource(PluginInstallFromPkgApi, "/workspaces/current/plugin/install/pkg")
api.add_resource(PluginInstallFromGithubApi, "/workspaces/current/plugin/install/github")
api.add_resource(PluginUpgradeFromMarketplaceApi, "/workspaces/current/plugin/upgrade/marketplace")
api.add_resource(PluginUpgradeFromGithubApi, "/workspaces/current/plugin/upgrade/github")
api.add_resource(PluginInstallFromMarketplaceApi, "/workspaces/current/plugin/install/marketplace")
api.add_resource(PluginFetchManifestApi, "/workspaces/current/plugin/fetch-manifest")
api.add_resource(PluginFetchInstallTasksApi, "/workspaces/current/plugin/tasks")
api.add_resource(PluginFetchInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>")
api.add_resource(PluginDeleteInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>/delete")
api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin/tasks/delete_all")
api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall")
api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change")
api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")

View File

@ -25,8 +25,10 @@ class ToolProviderListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user_id = current_user.id user = current_user
tenant_id = current_user.current_tenant_id
user_id = user.id
tenant_id = user.current_tenant_id
req = reqparse.RequestParser() req = reqparse.RequestParser()
req.add_argument( req.add_argument(
@ -47,28 +49,43 @@ class ToolBuiltinProviderListToolsApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider):
user_id = current_user.id user = current_user
tenant_id = current_user.current_tenant_id
tenant_id = user.current_tenant_id
return jsonable_encoder( return jsonable_encoder(
BuiltinToolManageService.list_builtin_tool_provider_tools( BuiltinToolManageService.list_builtin_tool_provider_tools(
user_id,
tenant_id, tenant_id,
provider, provider,
) )
) )
class ToolBuiltinProviderInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(user_id, tenant_id, provider))
class ToolBuiltinProviderDeleteApi(Resource): class ToolBuiltinProviderDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider): def post(self, provider):
if not current_user.is_admin_or_owner: user = current_user
if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
user_id = current_user.id user_id = user.id
tenant_id = current_user.current_tenant_id tenant_id = user.current_tenant_id
return BuiltinToolManageService.delete_builtin_tool_provider( return BuiltinToolManageService.delete_builtin_tool_provider(
user_id, user_id,
@ -82,11 +99,13 @@ class ToolBuiltinProviderUpdateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider): def post(self, provider):
if not current_user.is_admin_or_owner: user = current_user
if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
user_id = current_user.id user_id = user.id
tenant_id = current_user.current_tenant_id tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
@ -131,11 +150,13 @@ class ToolApiProviderAddApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
if not current_user.is_admin_or_owner: user = current_user
if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
user_id = current_user.id user_id = user.id
tenant_id = current_user.current_tenant_id tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
@ -168,6 +189,11 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("url", type=str, required=True, nullable=False, location="args") parser.add_argument("url", type=str, required=True, nullable=False, location="args")
@ -175,8 +201,8 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
args = parser.parse_args() args = parser.parse_args()
return ApiToolManageService.get_api_tool_provider_remote_schema( return ApiToolManageService.get_api_tool_provider_remote_schema(
current_user.id, user_id,
current_user.current_tenant_id, tenant_id,
args["url"], args["url"],
) )
@ -186,8 +212,10 @@ class ToolApiProviderListToolsApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user_id = current_user.id user = current_user
tenant_id = current_user.current_tenant_id
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -209,11 +237,13 @@ class ToolApiProviderUpdateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
if not current_user.is_admin_or_owner: user = current_user
if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
user_id = current_user.id user_id = user.id
tenant_id = current_user.current_tenant_id tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
@ -248,11 +278,13 @@ class ToolApiProviderDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
if not current_user.is_admin_or_owner: user = current_user
if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
user_id = current_user.id user_id = user.id
tenant_id = current_user.current_tenant_id tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -272,8 +304,10 @@ class ToolApiProviderGetApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user_id = current_user.id user = current_user
tenant_id = current_user.current_tenant_id
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -293,7 +327,11 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider):
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider) user = current_user
tenant_id = user.current_tenant_id
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, tenant_id)
class ToolApiProviderSchemaApi(Resource): class ToolApiProviderSchemaApi(Resource):
@ -344,11 +382,13 @@ class ToolWorkflowProviderCreateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
if not current_user.is_admin_or_owner: user = current_user
if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
user_id = current_user.id user_id = user.id
tenant_id = current_user.current_tenant_id tenant_id = user.current_tenant_id
reqparser = reqparse.RequestParser() reqparser = reqparse.RequestParser()
reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json") reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
@ -381,11 +421,13 @@ class ToolWorkflowProviderUpdateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
if not current_user.is_admin_or_owner: user = current_user
if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
user_id = current_user.id user_id = user.id
tenant_id = current_user.current_tenant_id tenant_id = user.current_tenant_id
reqparser = reqparse.RequestParser() reqparser = reqparse.RequestParser()
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
@ -421,11 +463,13 @@ class ToolWorkflowProviderDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
if not current_user.is_admin_or_owner: user = current_user
if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
user_id = current_user.id user_id = user.id
tenant_id = current_user.current_tenant_id tenant_id = user.current_tenant_id
reqparser = reqparse.RequestParser() reqparser = reqparse.RequestParser()
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
@ -444,8 +488,10 @@ class ToolWorkflowProviderGetApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user_id = current_user.id user = current_user
tenant_id = current_user.current_tenant_id
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args") parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
@ -476,8 +522,10 @@ class ToolWorkflowProviderListToolApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user_id = current_user.id user = current_user
tenant_id = current_user.current_tenant_id
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args") parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args")
@ -498,8 +546,10 @@ class ToolBuiltinListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user_id = current_user.id user = current_user
tenant_id = current_user.current_tenant_id
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder( return jsonable_encoder(
[ [
@ -517,8 +567,10 @@ class ToolApiListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user_id = current_user.id user = current_user
tenant_id = current_user.current_tenant_id
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder( return jsonable_encoder(
[ [
@ -536,8 +588,10 @@ class ToolWorkflowListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user_id = current_user.id user = current_user
tenant_id = current_user.current_tenant_id
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder( return jsonable_encoder(
[ [
@ -563,16 +617,18 @@ class ToolLabelsApi(Resource):
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers") api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
# builtin tool provider # builtin tool provider
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<provider>/tools") api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools")
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<provider>/delete") api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info")
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<provider>/update") api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
api.add_resource( api.add_resource(
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials" ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
) )
api.add_resource( api.add_resource(
ToolBuiltinProviderCredentialsSchemaApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials_schema" ToolBuiltinProviderCredentialsSchemaApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/credentials_schema",
) )
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<provider>/icon") api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")
# api tool provider # api tool provider
api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add") api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add")

View File

@ -7,6 +7,7 @@ from flask_login import current_user # type: ignore
from configs import dify_config from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError from controllers.console.workspace.error import AccountNotInitializedError
from extensions.ext_database import db
from models.model import DifySetup from models.model import DifySetup
from services.feature_service import FeatureService, LicenseStatus from services.feature_service import FeatureService, LicenseStatus
from services.operation_service import OperationService from services.operation_service import OperationService
@ -134,9 +135,13 @@ def setup_required(view):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args, **kwargs):
# check setup # check setup
if dify_config.EDITION == "SELF_HOSTED" and os.environ.get("INIT_PASSWORD") and not DifySetup.query.first(): if (
dify_config.EDITION == "SELF_HOSTED"
and os.environ.get("INIT_PASSWORD")
and not db.session.query(DifySetup).first()
):
raise NotInitValidateError() raise NotInitValidateError()
elif dify_config.EDITION == "SELF_HOSTED" and not DifySetup.query.first(): elif dify_config.EDITION == "SELF_HOSTED" and not db.session.query(DifySetup).first():
raise NotSetupError() raise NotSetupError()
return view(*args, **kwargs) return view(*args, **kwargs)

View File

@ -6,4 +6,4 @@ bp = Blueprint("files", __name__)
api = ExternalApi(bp) api = ExternalApi(bp)
from . import image_preview, tool_files from . import image_preview, tool_files, upload

View File

@ -0,0 +1,69 @@
from flask import request
from flask_restful import Resource, marshal_with # type: ignore
from werkzeug.exceptions import Forbidden
import services
from controllers.console.wraps import setup_required
from controllers.files import api
from controllers.files.error import UnsupportedFileTypeError
from controllers.inner_api.plugin.wraps import get_user
from controllers.service_api.app.error import FileTooLargeError
from core.file.helpers import verify_plugin_file_signature
from fields.file_fields import file_fields
from services.file_service import FileService
class PluginUploadFileApi(Resource):
@setup_required
@marshal_with(file_fields)
def post(self):
# get file from request
file = request.files["file"]
timestamp = request.args.get("timestamp")
nonce = request.args.get("nonce")
sign = request.args.get("sign")
tenant_id = request.args.get("tenant_id")
if not tenant_id:
raise Forbidden("Invalid request.")
user_id = request.args.get("user_id")
user = get_user(tenant_id, user_id)
filename = file.filename
mimetype = file.mimetype
if not filename or not mimetype:
raise Forbidden("Invalid request.")
if not timestamp or not nonce or not sign:
raise Forbidden("Invalid request.")
if not verify_plugin_file_signature(
filename=filename,
mimetype=mimetype,
tenant_id=tenant_id,
user_id=user_id,
timestamp=timestamp,
nonce=nonce,
sign=sign,
):
raise Forbidden("Invalid request.")
try:
upload_file = FileService.upload_file(
filename=filename,
content=file.read(),
mimetype=mimetype,
user=user,
source=None,
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return upload_file, 201
api.add_resource(PluginUploadFileApi, "/files/upload/for-plugin")

View File

@ -5,4 +5,5 @@ from libs.external_api import ExternalApi
bp = Blueprint("inner_api", __name__, url_prefix="/inner/api") bp = Blueprint("inner_api", __name__, url_prefix="/inner/api")
api = ExternalApi(bp) api = ExternalApi(bp)
from .plugin import plugin
from .workspace import workspace from .workspace import workspace

View File

@ -0,0 +1,293 @@
from flask_restful import Resource # type: ignore
from controllers.console.wraps import setup_required
from controllers.inner_api import api
from controllers.inner_api.plugin.wraps import get_user_tenant, plugin_data
from controllers.inner_api.wraps import plugin_inner_api_only
from core.file.helpers import get_signed_file_url_for_plugin
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse
from core.plugin.backwards_invocation.encrypt import PluginEncrypter
from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation
from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation
from core.plugin.entities.request import (
RequestInvokeApp,
RequestInvokeEncrypt,
RequestInvokeLLM,
RequestInvokeModeration,
RequestInvokeParameterExtractorNode,
RequestInvokeQuestionClassifierNode,
RequestInvokeRerank,
RequestInvokeSpeech2Text,
RequestInvokeSummary,
RequestInvokeTextEmbedding,
RequestInvokeTool,
RequestInvokeTTS,
RequestRequestUploadFile,
)
from core.tools.entities.tool_entities import ToolProviderType
from libs.helper import compact_generate_response
from models.account import Account, Tenant
from models.model import EndUser
class PluginInvokeLLMApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeLLM)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLM):
def generator():
response = PluginModelBackwardsInvocation.invoke_llm(user_model.id, tenant_model, payload)
return PluginModelBackwardsInvocation.convert_to_event_stream(response)
return compact_generate_response(generator())
class PluginInvokeTextEmbeddingApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeTextEmbedding)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTextEmbedding):
try:
return jsonable_encoder(
BaseBackwardsInvocationResponse(
data=PluginModelBackwardsInvocation.invoke_text_embedding(
user_id=user_model.id,
tenant=tenant_model,
payload=payload,
)
)
)
except Exception as e:
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
class PluginInvokeRerankApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeRerank)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeRerank):
try:
return jsonable_encoder(
BaseBackwardsInvocationResponse(
data=PluginModelBackwardsInvocation.invoke_rerank(
user_id=user_model.id,
tenant=tenant_model,
payload=payload,
)
)
)
except Exception as e:
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
class PluginInvokeTTSApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeTTS)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTTS):
def generator():
response = PluginModelBackwardsInvocation.invoke_tts(
user_id=user_model.id,
tenant=tenant_model,
payload=payload,
)
return PluginModelBackwardsInvocation.convert_to_event_stream(response)
return compact_generate_response(generator())
class PluginInvokeSpeech2TextApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeSpeech2Text)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSpeech2Text):
try:
return jsonable_encoder(
BaseBackwardsInvocationResponse(
data=PluginModelBackwardsInvocation.invoke_speech2text(
user_id=user_model.id,
tenant=tenant_model,
payload=payload,
)
)
)
except Exception as e:
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
class PluginInvokeModerationApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeModeration)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeModeration):
try:
return jsonable_encoder(
BaseBackwardsInvocationResponse(
data=PluginModelBackwardsInvocation.invoke_moderation(
user_id=user_model.id,
tenant=tenant_model,
payload=payload,
)
)
)
except Exception as e:
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
class PluginInvokeToolApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeTool)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTool):
def generator():
return PluginToolBackwardsInvocation.convert_to_event_stream(
PluginToolBackwardsInvocation.invoke_tool(
tenant_id=tenant_model.id,
user_id=user_model.id,
tool_type=ToolProviderType.value_of(payload.tool_type),
provider=payload.provider,
tool_name=payload.tool,
tool_parameters=payload.tool_parameters,
),
)
return compact_generate_response(generator())
class PluginInvokeParameterExtractorNodeApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeParameterExtractorNode)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeParameterExtractorNode):
try:
return jsonable_encoder(
BaseBackwardsInvocationResponse(
data=PluginNodeBackwardsInvocation.invoke_parameter_extractor(
tenant_id=tenant_model.id,
user_id=user_model.id,
parameters=payload.parameters,
model_config=payload.model,
instruction=payload.instruction,
query=payload.query,
)
)
)
except Exception as e:
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
class PluginInvokeQuestionClassifierNodeApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeQuestionClassifierNode)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeQuestionClassifierNode):
try:
return jsonable_encoder(
BaseBackwardsInvocationResponse(
data=PluginNodeBackwardsInvocation.invoke_question_classifier(
tenant_id=tenant_model.id,
user_id=user_model.id,
query=payload.query,
model_config=payload.model,
classes=payload.classes,
instruction=payload.instruction,
)
)
)
except Exception as e:
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
class PluginInvokeAppApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeApp)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeApp):
response = PluginAppBackwardsInvocation.invoke_app(
app_id=payload.app_id,
user_id=user_model.id,
tenant_id=tenant_model.id,
conversation_id=payload.conversation_id,
query=payload.query,
stream=payload.response_mode == "streaming",
inputs=payload.inputs,
files=payload.files,
)
return compact_generate_response(PluginAppBackwardsInvocation.convert_to_event_stream(response))
class PluginInvokeEncryptApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeEncrypt)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeEncrypt):
"""
encrypt or decrypt data
"""
try:
return BaseBackwardsInvocationResponse(
data=PluginEncrypter.invoke_encrypt(tenant_model, payload)
).model_dump()
except Exception as e:
return BaseBackwardsInvocationResponse(error=str(e)).model_dump()
class PluginInvokeSummaryApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeSummary)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSummary):
try:
return BaseBackwardsInvocationResponse(
data={
"summary": PluginModelBackwardsInvocation.invoke_summary(
user_id=user_model.id,
tenant=tenant_model,
payload=payload,
)
}
).model_dump()
except Exception as e:
return BaseBackwardsInvocationResponse(error=str(e)).model_dump()
class PluginUploadFileRequestApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestRequestUploadFile)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestRequestUploadFile):
# generate signed url
url = get_signed_file_url_for_plugin(payload.filename, payload.mimetype, tenant_model.id, user_model.id)
return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()
api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
api.add_resource(PluginInvokeTTSApi, "/invoke/tts")
api.add_resource(PluginInvokeSpeech2TextApi, "/invoke/speech2text")
api.add_resource(PluginInvokeModerationApi, "/invoke/moderation")
api.add_resource(PluginInvokeToolApi, "/invoke/tool")
api.add_resource(PluginInvokeParameterExtractorNodeApi, "/invoke/parameter-extractor")
api.add_resource(PluginInvokeQuestionClassifierNodeApi, "/invoke/question-classifier")
api.add_resource(PluginInvokeAppApi, "/invoke/app")
api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt")
api.add_resource(PluginInvokeSummaryApi, "/invoke/summary")
api.add_resource(PluginUploadFileRequestApi, "/upload/file/request")

View File

@ -0,0 +1,116 @@
from collections.abc import Callable
from functools import wraps
from typing import Optional
from flask import request
from flask_restful import reqparse # type: ignore
from pydantic import BaseModel
from sqlalchemy.orm import Session
from extensions.ext_database import db
from models.account import Account, Tenant
from models.model import EndUser
from services.account_service import AccountService
def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
try:
with Session(db.engine) as session:
if not user_id:
user_id = "DEFAULT-USER"
if user_id == "DEFAULT-USER":
user_model = session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first()
if not user_model:
user_model = EndUser(
tenant_id=tenant_id,
type="service_api",
is_anonymous=True if user_id == "DEFAULT-USER" else False,
session_id=user_id,
)
session.add(user_model)
session.commit()
else:
user_model = AccountService.load_user(user_id)
if not user_model:
user_model = session.query(EndUser).filter(EndUser.id == user_id).first()
if not user_model:
raise ValueError("user not found")
except Exception:
raise ValueError("user not found")
return user_model
def get_user_tenant(view: Optional[Callable] = None):
def decorator(view_func):
@wraps(view_func)
def decorated_view(*args, **kwargs):
# fetch json body
parser = reqparse.RequestParser()
parser.add_argument("tenant_id", type=str, required=True, location="json")
parser.add_argument("user_id", type=str, required=True, location="json")
kwargs = parser.parse_args()
user_id = kwargs.get("user_id")
tenant_id = kwargs.get("tenant_id")
if not tenant_id:
raise ValueError("tenant_id is required")
if not user_id:
user_id = "DEFAULT-USER"
del kwargs["tenant_id"]
del kwargs["user_id"]
try:
tenant_model = (
db.session.query(Tenant)
.filter(
Tenant.id == tenant_id,
)
.first()
)
except Exception:
raise ValueError("tenant not found")
if not tenant_model:
raise ValueError("tenant not found")
kwargs["tenant_model"] = tenant_model
kwargs["user_model"] = get_user(tenant_id, user_id)
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator
else:
return decorator(view)
def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel]):
def decorator(view_func):
def decorated_view(*args, **kwargs):
try:
data = request.get_json()
except Exception:
raise ValueError("invalid json")
try:
payload = payload_type(**data)
except Exception as e:
raise ValueError(f"invalid payload: {str(e)}")
kwargs["payload"] = payload
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator
else:
return decorator(view)

View File

@ -2,7 +2,7 @@ from flask_restful import Resource, reqparse # type: ignore
from controllers.console.wraps import setup_required from controllers.console.wraps import setup_required
from controllers.inner_api import api from controllers.inner_api import api
from controllers.inner_api.wraps import inner_api_only from controllers.inner_api.wraps import enterprise_inner_api_only
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
from models.account import Account from models.account import Account
from services.account_service import TenantService from services.account_service import TenantService
@ -10,7 +10,7 @@ from services.account_service import TenantService
class EnterpriseWorkspace(Resource): class EnterpriseWorkspace(Resource):
@setup_required @setup_required
@inner_api_only @enterprise_inner_api_only
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json") parser.add_argument("name", type=str, required=True, location="json")

View File

@ -10,7 +10,7 @@ from extensions.ext_database import db
from models.model import EndUser from models.model import EndUser
def inner_api_only(view): def enterprise_inner_api_only(view):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args, **kwargs):
if not dify_config.INNER_API: if not dify_config.INNER_API:
@ -18,7 +18,7 @@ def inner_api_only(view):
# get header 'X-Inner-Api-Key' # get header 'X-Inner-Api-Key'
inner_api_key = request.headers.get("X-Inner-Api-Key") inner_api_key = request.headers.get("X-Inner-Api-Key")
if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY: if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY_FOR_PLUGIN:
abort(401) abort(401)
return view(*args, **kwargs) return view(*args, **kwargs)
@ -26,7 +26,7 @@ def inner_api_only(view):
return decorated return decorated
def inner_api_user_auth(view): def enterprise_inner_api_user_auth(view):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args, **kwargs):
if not dify_config.INNER_API: if not dify_config.INNER_API:
@ -60,3 +60,19 @@ def inner_api_user_auth(view):
return view(*args, **kwargs) return view(*args, **kwargs)
return decorated return decorated
def plugin_inner_api_only(view):
@wraps(view)
def decorated(*args, **kwargs):
if not dify_config.PLUGIN_API_KEY:
abort(404)
# get header 'X-Inner-Api-Key'
inner_api_key = request.headers.get("X-Inner-Api-Key")
if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY_FOR_PLUGIN:
abort(404)
return view(*args, **kwargs)
return decorated

View File

@ -1,7 +1,6 @@
import json import json
import logging import logging
import uuid import uuid
from datetime import UTC, datetime
from typing import Optional, Union, cast from typing import Optional, Union, cast
from core.agent.entities import AgentEntity, AgentToolEntity from core.agent.entities import AgentEntity, AgentToolEntity
@ -32,19 +31,16 @@ from core.model_runtime.entities import (
from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.utils.extract_thread_messages import extract_thread_messages from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
ToolParameter, ToolParameter,
ToolRuntimeVariablePool,
) )
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.tool.tool import Tool
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory from factories import file_factory
from models.model import Conversation, Message, MessageAgentThought, MessageFile from models.model import Conversation, Message, MessageAgentThought, MessageFile
from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -62,11 +58,9 @@ class BaseAgentRunner(AppRunner):
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
message: Message, message: Message,
user_id: str, user_id: str,
model_instance: ModelInstance,
memory: Optional[TokenBufferMemory] = None, memory: Optional[TokenBufferMemory] = None,
prompt_messages: Optional[list[PromptMessage]] = None, prompt_messages: Optional[list[PromptMessage]] = None,
variables_pool: Optional[ToolRuntimeVariablePool] = None,
db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance,
) -> None: ) -> None:
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.application_generate_entity = application_generate_entity self.application_generate_entity = application_generate_entity
@ -79,8 +73,6 @@ class BaseAgentRunner(AppRunner):
self.user_id = user_id self.user_id = user_id
self.memory = memory self.memory = memory
self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or []) self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
self.variables_pool = variables_pool
self.db_variables_pool = db_variables
self.model_instance = model_instance self.model_instance = model_instance
# init callback # init callback
@ -141,11 +133,10 @@ class BaseAgentRunner(AppRunner):
agent_tool=tool, agent_tool=tool,
invoke_from=self.application_generate_entity.invoke_from, invoke_from=self.application_generate_entity.invoke_from,
) )
tool_entity.load_variables(self.variables_pool) assert tool_entity.entity.description
message_tool = PromptMessageTool( message_tool = PromptMessageTool(
name=tool.tool_name, name=tool.tool_name,
description=tool_entity.description.llm if tool_entity.description else "", description=tool_entity.entity.description.llm,
parameters={ parameters={
"type": "object", "type": "object",
"properties": {}, "properties": {},
@ -153,7 +144,7 @@ class BaseAgentRunner(AppRunner):
}, },
) )
parameters = tool_entity.get_all_runtime_parameters() parameters = tool_entity.get_merged_runtime_parameters()
for parameter in parameters: for parameter in parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM: if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue continue
@ -186,9 +177,11 @@ class BaseAgentRunner(AppRunner):
""" """
convert dataset retriever tool to prompt message tool convert dataset retriever tool to prompt message tool
""" """
assert tool.entity.description
prompt_tool = PromptMessageTool( prompt_tool = PromptMessageTool(
name=tool.identity.name if tool.identity else "unknown", name=tool.entity.identity.name,
description=tool.description.llm if tool.description else "", description=tool.entity.description.llm,
parameters={ parameters={
"type": "object", "type": "object",
"properties": {}, "properties": {},
@ -234,8 +227,7 @@ class BaseAgentRunner(AppRunner):
# save prompt tool # save prompt tool
prompt_messages_tools.append(prompt_tool) prompt_messages_tools.append(prompt_tool)
# save tool entity # save tool entity
if dataset_tool.identity is not None: tool_instances[dataset_tool.entity.identity.name] = dataset_tool
tool_instances[dataset_tool.identity.name] = dataset_tool
return tool_instances, prompt_messages_tools return tool_instances, prompt_messages_tools
@ -320,24 +312,23 @@ class BaseAgentRunner(AppRunner):
def save_agent_thought( def save_agent_thought(
self, self,
agent_thought: MessageAgentThought, agent_thought: MessageAgentThought,
tool_name: str, tool_name: str | None,
tool_input: Union[str, dict], tool_input: Union[str, dict, None],
thought: str, thought: str | None,
observation: Union[str, dict, None], observation: Union[str, dict, None],
tool_invoke_meta: Union[str, dict, None], tool_invoke_meta: Union[str, dict, None],
answer: str, answer: str | None,
messages_ids: list[str], messages_ids: list[str],
llm_usage: LLMUsage | None = None, llm_usage: LLMUsage | None = None,
): ):
""" """
Save agent thought Save agent thought
""" """
queried_thought = ( updated_agent_thought = (
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first() db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
) )
if not queried_thought: if not updated_agent_thought:
raise ValueError(f"Agent thought {agent_thought.id} not found") raise ValueError("agent thought not found")
agent_thought = queried_thought
if thought: if thought:
agent_thought.thought = thought agent_thought.thought = thought
@ -349,39 +340,39 @@ class BaseAgentRunner(AppRunner):
if isinstance(tool_input, dict): if isinstance(tool_input, dict):
try: try:
tool_input = json.dumps(tool_input, ensure_ascii=False) tool_input = json.dumps(tool_input, ensure_ascii=False)
except Exception as e: except Exception:
tool_input = json.dumps(tool_input) tool_input = json.dumps(tool_input)
agent_thought.tool_input = tool_input updated_agent_thought.tool_input = tool_input
if observation: if observation:
if isinstance(observation, dict): if isinstance(observation, dict):
try: try:
observation = json.dumps(observation, ensure_ascii=False) observation = json.dumps(observation, ensure_ascii=False)
except Exception as e: except Exception:
observation = json.dumps(observation) observation = json.dumps(observation)
agent_thought.observation = observation updated_agent_thought.observation = observation
if answer: if answer:
agent_thought.answer = answer agent_thought.answer = answer
if messages_ids is not None and len(messages_ids) > 0: if messages_ids is not None and len(messages_ids) > 0:
agent_thought.message_files = json.dumps(messages_ids) updated_agent_thought.message_files = json.dumps(messages_ids)
if llm_usage: if llm_usage:
agent_thought.message_token = llm_usage.prompt_tokens updated_agent_thought.message_token = llm_usage.prompt_tokens
agent_thought.message_price_unit = llm_usage.prompt_price_unit updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit
agent_thought.message_unit_price = llm_usage.prompt_unit_price updated_agent_thought.message_unit_price = llm_usage.prompt_unit_price
agent_thought.answer_token = llm_usage.completion_tokens updated_agent_thought.answer_token = llm_usage.completion_tokens
agent_thought.answer_price_unit = llm_usage.completion_price_unit updated_agent_thought.answer_price_unit = llm_usage.completion_price_unit
agent_thought.answer_unit_price = llm_usage.completion_unit_price updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price
agent_thought.tokens = llm_usage.total_tokens updated_agent_thought.tokens = llm_usage.total_tokens
agent_thought.total_price = llm_usage.total_price updated_agent_thought.total_price = llm_usage.total_price
# check if tool labels is not empty # check if tool labels is not empty
labels = agent_thought.tool_labels or {} labels = updated_agent_thought.tool_labels or {}
tools = agent_thought.tool.split(";") if agent_thought.tool else [] tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else []
for tool in tools: for tool in tools:
if not tool: if not tool:
continue continue
@ -392,42 +383,20 @@ class BaseAgentRunner(AppRunner):
else: else:
labels[tool] = {"en_US": tool, "zh_Hans": tool} labels[tool] = {"en_US": tool, "zh_Hans": tool}
agent_thought.tool_labels_str = json.dumps(labels) updated_agent_thought.tool_labels_str = json.dumps(labels)
if tool_invoke_meta is not None: if tool_invoke_meta is not None:
if isinstance(tool_invoke_meta, dict): if isinstance(tool_invoke_meta, dict):
try: try:
tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False) tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
except Exception as e: except Exception:
tool_invoke_meta = json.dumps(tool_invoke_meta) tool_invoke_meta = json.dumps(tool_invoke_meta)
agent_thought.tool_meta_str = tool_invoke_meta updated_agent_thought.tool_meta_str = tool_invoke_meta
db.session.commit() db.session.commit()
db.session.close() db.session.close()
def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
"""
convert tool variables to db variables
"""
queried_variables = (
db.session.query(ToolConversationVariables)
.filter(
ToolConversationVariables.conversation_id == self.message.conversation_id,
)
.first()
)
if not queried_variables:
return
db_variables = queried_variables
db_variables.updated_at = datetime.now(UTC).replace(tzinfo=None)
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
db.session.commit()
db.session.close()
def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
""" """
Organize agent history Organize agent history
@ -464,11 +433,11 @@ class BaseAgentRunner(AppRunner):
tool_call_response: list[ToolPromptMessage] = [] tool_call_response: list[ToolPromptMessage] = []
try: try:
tool_inputs = json.loads(agent_thought.tool_input) tool_inputs = json.loads(agent_thought.tool_input)
except Exception as e: except Exception:
tool_inputs = {tool: {} for tool in tools} tool_inputs = {tool: {} for tool in tools}
try: try:
tool_responses = json.loads(agent_thought.observation) tool_responses = json.loads(agent_thought.observation)
except Exception as e: except Exception:
tool_responses = dict.fromkeys(tools, agent_thought.observation) tool_responses = dict.fromkeys(tools, agent_thought.observation)
for tool in tools: for tool in tools:
@ -515,7 +484,11 @@ class BaseAgentRunner(AppRunner):
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
if not files: if not files:
return UserPromptMessage(content=message.query) return UserPromptMessage(content=message.query)
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict()) if message.app_model_config:
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
else:
file_extra_config = None
if not file_extra_config: if not file_extra_config:
return UserPromptMessage(content=message.query) return UserPromptMessage(content=message.query)

View File

@ -1,6 +1,6 @@
import json import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional from typing import Any, Optional
from core.agent.base_agent_runner import BaseAgentRunner from core.agent.base_agent_runner import BaseAgentRunner
@ -18,8 +18,8 @@ from core.model_runtime.entities.message_entities import (
) )
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool.tool import Tool
from core.tools.tool_engine import ToolEngine from core.tools.tool_engine import ToolEngine
from models.model import Message from models.model import Message
@ -27,11 +27,11 @@ from models.model import Message
class CotAgentRunner(BaseAgentRunner, ABC): class CotAgentRunner(BaseAgentRunner, ABC):
_is_first_iteration = True _is_first_iteration = True
_ignore_observation_providers = ["wenxin"] _ignore_observation_providers = ["wenxin"]
_historic_prompt_messages: list[PromptMessage] | None = None _historic_prompt_messages: list[PromptMessage]
_agent_scratchpad: list[AgentScratchpadUnit] | None = None _agent_scratchpad: list[AgentScratchpadUnit]
_instruction: str = "" # FIXME this must be str for now _instruction: str
_query: str | None = None _query: str
_prompt_messages_tools: list[PromptMessageTool] = [] _prompt_messages_tools: Sequence[PromptMessageTool]
def run( def run(
self, self,
@ -42,6 +42,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
""" """
Run Cot agent application Run Cot agent application
""" """
app_generate_entity = self.application_generate_entity app_generate_entity = self.application_generate_entity
self._repack_app_generate_entity(app_generate_entity) self._repack_app_generate_entity(app_generate_entity)
self._init_react_state(query) self._init_react_state(query)
@ -54,17 +55,19 @@ class CotAgentRunner(BaseAgentRunner, ABC):
app_generate_entity.model_conf.stop.append("Observation") app_generate_entity.model_conf.stop.append("Observation")
app_config = self.app_config app_config = self.app_config
assert app_config.agent
# init instruction # init instruction
inputs = inputs or {} inputs = inputs or {}
instruction = app_config.prompt_template.simple_prompt_template instruction = app_config.prompt_template.simple_prompt_template or ""
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction=instruction or "", inputs=inputs) self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
iteration_step = 1 iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration if app_config.agent else 5, 5) + 1 max_iteration_steps = min(app_config.agent.max_iteration if app_config.agent else 5, 5) + 1
# convert tools into ModelRuntime Tool format # convert tools into ModelRuntime Tool format
tool_instances, self._prompt_messages_tools = self._init_prompt_tools() tool_instances, prompt_messages_tools = self._init_prompt_tools()
self._prompt_messages_tools = prompt_messages_tools
function_call_state = True function_call_state = True
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
@ -116,14 +119,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
callbacks=[], callbacks=[],
) )
if not isinstance(chunks, Generator): usage_dict: dict[str, Optional[LLMUsage]] = {}
raise ValueError("Expected streaming response from LLM")
# check llm result
if not chunks:
raise ValueError("failed to invoke llm")
usage_dict: dict[str, Optional[LLMUsage]] = {"usage": None}
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict) react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
scratchpad = AgentScratchpadUnit( scratchpad = AgentScratchpadUnit(
agent_response="", agent_response="",
@ -143,25 +139,25 @@ class CotAgentRunner(BaseAgentRunner, ABC):
if isinstance(chunk, AgentScratchpadUnit.Action): if isinstance(chunk, AgentScratchpadUnit.Action):
action = chunk action = chunk
# detect action # detect action
if scratchpad.agent_response is not None: assert scratchpad.agent_response is not None
scratchpad.agent_response += json.dumps(chunk.model_dump()) scratchpad.agent_response += json.dumps(chunk.model_dump())
scratchpad.action_str = json.dumps(chunk.model_dump()) scratchpad.action_str = json.dumps(chunk.model_dump())
scratchpad.action = action scratchpad.action = action
else: else:
if scratchpad.agent_response is not None: assert scratchpad.agent_response is not None
scratchpad.agent_response += chunk scratchpad.agent_response += chunk
if scratchpad.thought is not None: assert scratchpad.thought is not None
scratchpad.thought += chunk scratchpad.thought += chunk
yield LLMResultChunk( yield LLMResultChunk(
model=self.model_config.model, model=self.model_config.model,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
system_fingerprint="", system_fingerprint="",
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None), delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
) )
if scratchpad.thought is not None:
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you" assert scratchpad.thought is not None
if self._agent_scratchpad is not None: scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
self._agent_scratchpad.append(scratchpad) self._agent_scratchpad.append(scratchpad)
# get llm usage # get llm usage
if "usage" in usage_dict: if "usage" in usage_dict:
@ -256,8 +252,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
answer=final_answer, answer=final_answer,
messages_ids=[], messages_ids=[],
) )
if self.variables_pool is not None and self.db_variables_pool is not None:
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event # publish end event
self.queue_manager.publish( self.queue_manager.publish(
QueueMessageEndEvent( QueueMessageEndEvent(
@ -275,7 +269,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
def _handle_invoke_action( def _handle_invoke_action(
self, self,
action: AgentScratchpadUnit.Action, action: AgentScratchpadUnit.Action,
tool_instances: dict[str, Tool], tool_instances: Mapping[str, Tool],
message_file_ids: list[str], message_file_ids: list[str],
trace_manager: Optional[TraceQueueManager] = None, trace_manager: Optional[TraceQueueManager] = None,
) -> tuple[str, ToolInvokeMeta]: ) -> tuple[str, ToolInvokeMeta]:
@ -315,11 +309,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
) )
# publish files # publish files
for message_file_id, save_as in message_files: for message_file_id in message_files:
if save_as is not None and self.variables_pool:
# FIXME the save_as type is confusing, it should be a string or not
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=str(save_as))
# publish message file # publish message file
self.queue_manager.publish( self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
@ -342,7 +332,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
for key, value in inputs.items(): for key, value in inputs.items():
try: try:
instruction = instruction.replace(f"{{{{{key}}}}}", str(value)) instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
except Exception as e: except Exception:
continue continue
return instruction return instruction
@ -379,7 +369,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
return message return message
def _organize_historic_prompt_messages( def _organize_historic_prompt_messages(
self, current_session_messages: Optional[list[PromptMessage]] = None self, current_session_messages: list[PromptMessage] | None = None
) -> list[PromptMessage]: ) -> list[PromptMessage]:
""" """
organize historic prompt messages organize historic prompt messages
@ -391,8 +381,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
for message in self.history_prompt_messages: for message in self.history_prompt_messages:
if isinstance(message, AssistantPromptMessage): if isinstance(message, AssistantPromptMessage):
if not current_scratchpad: if not current_scratchpad:
if not isinstance(message.content, str | None): assert isinstance(message.content, str)
raise NotImplementedError("expected str type")
current_scratchpad = AgentScratchpadUnit( current_scratchpad = AgentScratchpadUnit(
agent_response=message.content, agent_response=message.content,
thought=message.content or "I am thinking about how to help you", thought=message.content or "I am thinking about how to help you",
@ -411,9 +400,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
except: except:
pass pass
elif isinstance(message, ToolPromptMessage): elif isinstance(message, ToolPromptMessage):
if not current_scratchpad: if current_scratchpad:
continue assert isinstance(message.content, str)
if isinstance(message.content, str):
current_scratchpad.observation = message.content current_scratchpad.observation = message.content
else: else:
raise NotImplementedError("expected str type") raise NotImplementedError("expected str type")

View File

@ -19,8 +19,8 @@ class CotChatAgentRunner(CotAgentRunner):
""" """
Organize system prompt Organize system prompt
""" """
if not self.app_config.agent: assert self.app_config.agent
raise ValueError("Agent configuration is not set") assert self.app_config.agent.prompt
prompt_entity = self.app_config.agent.prompt prompt_entity = self.app_config.agent.prompt
if not prompt_entity: if not prompt_entity:
@ -83,8 +83,10 @@ class CotChatAgentRunner(CotAgentRunner):
assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str
for unit in agent_scratchpad: for unit in agent_scratchpad:
if unit.is_final(): if unit.is_final():
assert isinstance(assistant_message.content, str)
assistant_message.content += f"Final Answer: {unit.agent_response}" assistant_message.content += f"Final Answer: {unit.agent_response}"
else: else:
assert isinstance(assistant_message.content, str)
assistant_message.content += f"Thought: {unit.thought}\n\n" assistant_message.content += f"Thought: {unit.thought}\n\n"
if unit.action_str: if unit.action_str:
assistant_message.content += f"Action: {unit.action_str}\n\n" assistant_message.content += f"Action: {unit.action_str}\n\n"

View File

@ -1,18 +1,21 @@
from enum import Enum from enum import StrEnum
from typing import Any, Literal, Optional, Union from typing import Any, Optional, Union
from pydantic import BaseModel from pydantic import BaseModel
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
class AgentToolEntity(BaseModel): class AgentToolEntity(BaseModel):
""" """
Agent Tool Entity. Agent Tool Entity.
""" """
provider_type: Literal["builtin", "api", "workflow"] provider_type: ToolProviderType
provider_id: str provider_id: str
tool_name: str tool_name: str
tool_parameters: dict[str, Any] = {} tool_parameters: dict[str, Any] = {}
plugin_unique_identifier: str | None = None
class AgentPromptEntity(BaseModel): class AgentPromptEntity(BaseModel):
@ -66,7 +69,7 @@ class AgentEntity(BaseModel):
Agent Entity. Agent Entity.
""" """
class Strategy(Enum): class Strategy(StrEnum):
""" """
Agent Strategy. Agent Strategy.
""" """
@ -78,5 +81,13 @@ class AgentEntity(BaseModel):
model: str model: str
strategy: Strategy strategy: Strategy
prompt: Optional[AgentPromptEntity] = None prompt: Optional[AgentPromptEntity] = None
tools: list[AgentToolEntity] | None = None tools: Optional[list[AgentToolEntity]] = None
max_iteration: int = 5 max_iteration: int = 5
class AgentInvokeMessage(ToolInvokeMessage):
"""
Agent Invoke Message.
"""
pass

View File

@ -46,18 +46,20 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# convert tools into ModelRuntime Tool format # convert tools into ModelRuntime Tool format
tool_instances, prompt_messages_tools = self._init_prompt_tools() tool_instances, prompt_messages_tools = self._init_prompt_tools()
assert app_config.agent
iteration_step = 1 iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
# continue to run until there is not any tool call # continue to run until there is not any tool call
function_call_state = True function_call_state = True
llm_usage: dict[str, LLMUsage] = {"usage": LLMUsage.empty_usage()} llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
final_answer = "" final_answer = ""
# get tracing instance # get tracing instance
trace_manager = app_generate_entity.trace_manager trace_manager = app_generate_entity.trace_manager
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
if not final_llm_usage_dict["usage"]: if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage final_llm_usage_dict["usage"] = usage
else: else:
@ -107,7 +109,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
current_llm_usage = None current_llm_usage = None
if self.stream_tool_call and isinstance(chunks, Generator): if isinstance(chunks, Generator):
is_first_chunk = True is_first_chunk = True
for chunk in chunks: for chunk in chunks:
if is_first_chunk: if is_first_chunk:
@ -124,7 +126,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_call_inputs = json.dumps( tool_call_inputs = json.dumps(
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
) )
except json.JSONDecodeError as e: except json.JSONDecodeError:
# ensure ascii to avoid encoding error # ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
@ -140,7 +142,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
current_llm_usage = chunk.delta.usage current_llm_usage = chunk.delta.usage
yield chunk yield chunk
elif not self.stream_tool_call and isinstance(chunks, LLMResult): else:
result = chunks result = chunks
# check if there is any tool call # check if there is any tool call
if self.check_blocking_tool_calls(result): if self.check_blocking_tool_calls(result):
@ -151,7 +153,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_call_inputs = json.dumps( tool_call_inputs = json.dumps(
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
) )
except json.JSONDecodeError as e: except json.JSONDecodeError:
# ensure ascii to avoid encoding error # ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
@ -183,8 +185,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
usage=result.usage, usage=result.usage,
), ),
) )
else:
raise RuntimeError(f"invalid chunks type: {type(chunks)}")
assistant_message = AssistantPromptMessage(content="", tool_calls=[]) assistant_message = AssistantPromptMessage(content="", tool_calls=[])
if tool_calls: if tool_calls:
@ -243,15 +243,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
invoke_from=self.application_generate_entity.invoke_from, invoke_from=self.application_generate_entity.invoke_from,
agent_tool_callback=self.agent_callback, agent_tool_callback=self.agent_callback,
trace_manager=trace_manager, trace_manager=trace_manager,
app_id=self.application_generate_entity.app_config.app_id,
message_id=self.message.id,
conversation_id=self.conversation.id,
) )
# publish files # publish files
for message_file_id, save_as in message_files: for message_file_id in message_files:
if save_as:
if self.variables_pool:
self.variables_pool.set_file(
tool_name=tool_call_name, value=message_file_id, name=save_as
)
# publish message file # publish message file
self.queue_manager.publish( self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
@ -303,8 +300,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
iteration_step += 1 iteration_step += 1
if self.variables_pool and self.db_variables_pool:
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event # publish end event
self.queue_manager.publish( self.queue_manager.publish(
QueueMessageEndEvent( QueueMessageEndEvent(
@ -335,9 +330,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return True return True
return False return False
def extract_tool_calls( def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
self, llm_result_chunk: LLMResultChunk
) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
""" """
Extract tool calls from llm result chunk Extract tool calls from llm result chunk
@ -360,7 +353,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return tool_calls return tool_calls
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
""" """
Extract blocking tool calls from llm result Extract blocking tool calls from llm result
@ -383,9 +376,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return tool_calls return tool_calls
def _init_system_message( def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
self, prompt_template: str, prompt_messages: Optional[list[PromptMessage]] = None
) -> list[PromptMessage]:
""" """
Initialize system message Initialize system message
""" """

View File

@ -0,0 +1,89 @@
import enum
from typing import Any, Optional
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
from core.entities.parameter_entities import CommonParameterType
from core.plugin.entities.parameters import (
PluginParameter,
as_normal_type,
cast_parameter_value,
init_frontend_parameter,
)
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ToolIdentity,
ToolProviderIdentity,
)
class AgentStrategyProviderIdentity(ToolProviderIdentity):
"""
Inherits from ToolProviderIdentity, without any additional fields.
"""
pass
class AgentStrategyParameter(PluginParameter):
class AgentStrategyParameterType(enum.StrEnum):
"""
Keep all the types from PluginParameterType
"""
STRING = CommonParameterType.STRING.value
NUMBER = CommonParameterType.NUMBER.value
BOOLEAN = CommonParameterType.BOOLEAN.value
SELECT = CommonParameterType.SELECT.value
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
FILE = CommonParameterType.FILE.value
FILES = CommonParameterType.FILES.value
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
# deprecated, should not use.
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
def as_normal_type(self):
return as_normal_type(self)
def cast_value(self, value: Any):
return cast_parameter_value(self, value)
type: AgentStrategyParameterType = Field(..., description="The type of the parameter")
def init_frontend_parameter(self, value: Any):
return init_frontend_parameter(self, self.type, value)
class AgentStrategyProviderEntity(BaseModel):
identity: AgentStrategyProviderIdentity
plugin_id: Optional[str] = Field(None, description="The id of the plugin")
class AgentStrategyIdentity(ToolIdentity):
"""
Inherits from ToolIdentity, without any additional fields.
"""
pass
class AgentStrategyEntity(BaseModel):
identity: AgentStrategyIdentity
parameters: list[AgentStrategyParameter] = Field(default_factory=list)
description: I18nObject = Field(..., description="The description of the agent strategy")
output_schema: Optional[dict] = None
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
@field_validator("parameters", mode="before")
@classmethod
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[AgentStrategyParameter]:
return v or []
class AgentProviderEntityWithPlugin(AgentStrategyProviderEntity):
strategies: list[AgentStrategyEntity] = Field(default_factory=list)

View File

@ -0,0 +1,42 @@
from abc import ABC, abstractmethod
from collections.abc import Generator, Sequence
from typing import Any, Optional
from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyParameter
class BaseAgentStrategy(ABC):
"""
Agent Strategy
"""
def invoke(
self,
params: dict[str, Any],
user_id: str,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> Generator[AgentInvokeMessage, None, None]:
"""
Invoke the agent strategy.
"""
yield from self._invoke(params, user_id, conversation_id, app_id, message_id)
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
"""
Get the parameters for the agent strategy.
"""
return []
@abstractmethod
def _invoke(
self,
params: dict[str, Any],
user_id: str,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> Generator[AgentInvokeMessage, None, None]:
pass

View File

@ -0,0 +1,59 @@
from collections.abc import Generator, Sequence
from typing import Any, Optional
from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
from core.agent.strategy.base import BaseAgentStrategy
from core.plugin.manager.agent import PluginAgentManager
from core.plugin.utils.converter import convert_parameters_to_plugin_format
class PluginAgentStrategy(BaseAgentStrategy):
"""
Agent Strategy
"""
tenant_id: str
declaration: AgentStrategyEntity
def __init__(self, tenant_id: str, declaration: AgentStrategyEntity):
self.tenant_id = tenant_id
self.declaration = declaration
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
return self.declaration.parameters
def initialize_parameters(self, params: dict[str, Any]) -> dict[str, Any]:
"""
Initialize the parameters for the agent strategy.
"""
for parameter in self.declaration.parameters:
params[parameter.name] = parameter.init_frontend_parameter(params.get(parameter.name))
return params
def _invoke(
self,
params: dict[str, Any],
user_id: str,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> Generator[AgentInvokeMessage, None, None]:
"""
Invoke the agent strategy.
"""
manager = PluginAgentManager()
initialized_params = self.initialize_parameters(params)
params = convert_parameters_to_plugin_format(initialized_params)
yield from manager.invoke(
tenant_id=self.tenant_id,
user_id=user_id,
agent_provider=self.declaration.identity.provider,
agent_strategy=self.declaration.identity.name,
agent_params=params,
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
)

View File

@ -4,7 +4,8 @@ from core.app.app_config.entities import EasyUIBasedAppConfig
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
@ -63,14 +64,14 @@ class ModelConfigConverter:
stop = completion_params["stop"] stop = completion_params["stop"]
del completion_params["stop"] del completion_params["stop"]
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
# get model mode # get model mode
model_mode = model_config.mode model_mode = model_config.mode
if not model_mode: if not model_mode:
mode_enum = model_type_instance.get_model_mode(model=model_config.model, credentials=model_credentials) model_mode = LLMMode.CHAT.value
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
model_mode = mode_enum.value model_mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE]).value
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
if not model_schema: if not model_schema:
raise ValueError(f"Model {model_name} not exist.") raise ValueError(f"Model {model_name} not exist.")

View File

@ -2,8 +2,9 @@ from collections.abc import Mapping
from typing import Any from typing import Any
from core.app.app_config.entities import ModelConfigEntity from core.app.app_config.entities import ModelConfigEntity
from core.entities import DEFAULT_PLUGIN_ID
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
@ -53,9 +54,18 @@ class ModelConfigManager:
raise ValueError("model must be of object type") raise ValueError("model must be of object type")
# model.provider # model.provider
model_provider_factory = ModelProviderFactory(tenant_id)
provider_entities = model_provider_factory.get_providers() provider_entities = model_provider_factory.get_providers()
model_provider_names = [provider.provider for provider in provider_entities] model_provider_names = [provider.provider for provider in provider_entities]
if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names: if "provider" not in config["model"]:
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
if "/" not in config["model"]["provider"]:
config["model"]["provider"] = (
f"{DEFAULT_PLUGIN_ID}/{config['model']['provider']}/{config['model']['provider']}"
)
if config["model"]["provider"] not in model_provider_names:
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
# model.name # model.name

View File

@ -37,17 +37,6 @@ logger = logging.getLogger(__name__)
class AdvancedChatAppGenerator(MessageBasedAppGenerator): class AdvancedChatAppGenerator(MessageBasedAppGenerator):
_dialogue_count: int _dialogue_count: int
@overload
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[str, None, None]: ...
@overload @overload
def generate( def generate(
self, self,
@ -65,20 +54,31 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: Mapping[str, Any], args: Mapping,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = True, streaming: Literal[True],
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ... ) -> Generator[Mapping | str, None, None]: ...
@overload
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping,
invoke_from: InvokeFrom,
streaming: bool,
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: ...
def generate( def generate(
self, self,
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: Mapping[str, Any], args: Mapping,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = True, streaming: bool = True,
): ) -> Mapping[str, Any] | Generator[str | Mapping, None, None]:
""" """
Generate App response. Generate App response.
@ -156,6 +156,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
return self._generate( return self._generate(
workflow=workflow, workflow=workflow,
@ -167,8 +169,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
) )
def single_iteration_generate( def single_iteration_generate(
self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, streaming: bool = True self,
) -> Mapping[str, Any] | Generator[str, None, None]: app_model: App,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: Mapping,
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
""" """
Generate App response. Generate App response.
@ -205,6 +213,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
), ),
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
return self._generate( return self._generate(
workflow=workflow, workflow=workflow,
@ -224,7 +234,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
application_generate_entity: AdvancedChatAppGenerateEntity, application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Optional[Conversation] = None, conversation: Optional[Conversation] = None,
stream: bool = True, stream: bool = True,
) -> Mapping[str, Any] | Generator[str, None, None]: ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
""" """
Generate App response. Generate App response.

View File

@ -56,7 +56,7 @@ def _process_future(
class AppGeneratorTTSPublisher: class AppGeneratorTTSPublisher:
def __init__(self, tenant_id: str, voice: str): def __init__(self, tenant_id: str, voice: str, language: Optional[str] = None):
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.msg_text = "" self.msg_text = ""
@ -67,7 +67,7 @@ class AppGeneratorTTSPublisher:
self.model_instance = self.model_manager.get_default_model_instance( self.model_instance = self.model_manager.get_default_model_instance(
tenant_id=self.tenant_id, model_type=ModelType.TTS tenant_id=self.tenant_id, model_type=ModelType.TTS
) )
self.voices = self.model_instance.get_tts_voices() self.voices = self.model_instance.get_tts_voices(language=language)
values = [voice.get("value") for voice in self.voices] values = [voice.get("value") for voice in self.voices]
self.voice = voice self.voice = voice
if not voice or voice not in values: if not voice or voice not in values:

View File

@ -77,7 +77,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow, workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id, node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs, user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
) )
else: else:
inputs = self.application_generate_entity.inputs inputs = self.application_generate_entity.inputs

View File

@ -1,4 +1,3 @@
import json
from collections.abc import Generator from collections.abc import Generator
from typing import Any, cast from typing import Any, cast
@ -58,7 +57,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod @classmethod
def convert_stream_full_response( def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None] cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[str, Any, None]: ) -> Generator[dict | str, Any, None]:
""" """
Convert stream full response. Convert stream full response.
:param stream_response: stream response :param stream_response: stream response
@ -84,12 +83,12 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
response_chunk.update(data) response_chunk.update(data)
else: else:
response_chunk.update(sub_stream_response.to_dict()) response_chunk.update(sub_stream_response.to_dict())
yield json.dumps(response_chunk) yield response_chunk
@classmethod @classmethod
def convert_stream_simple_response( def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None] cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[str, Any, None]: ) -> Generator[dict | str, Any, None]:
""" """
Convert stream simple response. Convert stream simple response.
:param stream_response: stream response :param stream_response: stream response
@ -123,4 +122,4 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
else: else:
response_chunk.update(sub_stream_response.to_dict()) response_chunk.update(sub_stream_response.to_dict())
yield json.dumps(response_chunk) yield response_chunk

View File

@ -17,6 +17,7 @@ from core.app.entities.app_invoke_entities import (
) )
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
QueueAdvancedChatMessageEndEvent, QueueAdvancedChatMessageEndEvent,
QueueAgentLogEvent,
QueueAnnotationReplyEvent, QueueAnnotationReplyEvent,
QueueErrorEvent, QueueErrorEvent,
QueueIterationCompletedEvent, QueueIterationCompletedEvent,
@ -219,7 +220,9 @@ class AdvancedChatAppGenerateTaskPipeline:
and features_dict["text_to_speech"].get("enabled") and features_dict["text_to_speech"].get("enabled")
and features_dict["text_to_speech"].get("autoPlay") == "enabled" and features_dict["text_to_speech"].get("autoPlay") == "enabled"
): ):
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice")) tts_publisher = AppGeneratorTTSPublisher(
tenant_id, features_dict["text_to_speech"].get("voice"), features_dict["text_to_speech"].get("language")
)
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True: while True:
@ -247,7 +250,7 @@ class AdvancedChatAppGenerateTaskPipeline:
else: else:
start_listener_time = time.time() start_listener_time = time.time()
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id) yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
except Exception as e: except Exception:
logger.exception(f"Failed to listen audio message, task_id: {task_id}") logger.exception(f"Failed to listen audio message, task_id: {task_id}")
break break
if tts_publisher: if tts_publisher:
@ -640,6 +643,10 @@ class AdvancedChatAppGenerateTaskPipeline:
session.commit() session.commit()
yield self._message_end_to_stream_response() yield self._message_end_to_stream_response()
elif isinstance(event, QueueAgentLogEvent):
yield self._workflow_cycle_manager._handle_agent_log(
task_id=self._application_generate_entity.task_id, event=event
)
else: else:
continue continue

View File

@ -1,3 +1,4 @@
import contextvars
import logging import logging
import threading import threading
import uuid import uuid
@ -29,17 +30,6 @@ logger = logging.getLogger(__name__)
class AgentChatAppGenerator(MessageBasedAppGenerator): class AgentChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[str, None, None]: ...
@overload @overload
def generate( def generate(
self, self,
@ -51,6 +41,17 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
streaming: Literal[False], streaming: Literal[False],
) -> Mapping[str, Any]: ... ) -> Mapping[str, Any]: ...
@overload
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[Mapping | str, None, None]: ...
@overload @overload
def generate( def generate(
self, self,
@ -60,7 +61,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool, streaming: bool,
) -> Mapping[str, Any] | Generator[str, None, None]: ... ) -> Union[Mapping, Generator[Mapping | str, None, None]]: ...
def generate( def generate(
self, self,
@ -70,7 +71,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = True, streaming: bool = True,
): ) -> Union[Mapping, Generator[Mapping | str, None, None]]:
""" """
Generate App response. Generate App response.
@ -182,6 +183,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
target=self._generate_worker, target=self._generate_worker,
kwargs={ kwargs={
"flask_app": current_app._get_current_object(), # type: ignore "flask_app": current_app._get_current_object(), # type: ignore
"context": contextvars.copy_context(),
"application_generate_entity": application_generate_entity, "application_generate_entity": application_generate_entity,
"queue_manager": queue_manager, "queue_manager": queue_manager,
"conversation_id": conversation.id, "conversation_id": conversation.id,
@ -206,6 +208,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
def _generate_worker( def _generate_worker(
self, self,
flask_app: Flask, flask_app: Flask,
context: contextvars.Context,
application_generate_entity: AgentChatAppGenerateEntity, application_generate_entity: AgentChatAppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation_id: str, conversation_id: str,
@ -220,6 +223,9 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
:param message_id: message ID :param message_id: message ID
:return: :return:
""" """
for var, val in context.items():
var.set(val)
with flask_app.app_context(): with flask_app.app_context():
try: try:
# get conversation and message # get conversation and message

View File

@ -16,10 +16,8 @@ from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.base import ModerationError from core.moderation.base import ModerationError
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
from extensions.ext_database import db from extensions.ext_database import db
from models.model import App, Conversation, Message, MessageAgentThought from models.model import App, Conversation, Message, MessageAgentThought
from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -64,8 +62,8 @@ class AgentChatAppRunner(AppRunner):
app_record=app_record, app_record=app_record,
model_config=application_generate_entity.model_conf, model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template, prompt_template_entity=app_config.prompt_template,
inputs=inputs, inputs=dict(inputs),
files=files, files=list(files),
query=query, query=query,
) )
@ -86,8 +84,8 @@ class AgentChatAppRunner(AppRunner):
app_record=app_record, app_record=app_record,
model_config=application_generate_entity.model_conf, model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template, prompt_template_entity=app_config.prompt_template,
inputs=inputs, inputs=dict(inputs),
files=files, files=list(files),
query=query, query=query,
memory=memory, memory=memory,
) )
@ -99,8 +97,8 @@ class AgentChatAppRunner(AppRunner):
app_id=app_record.id, app_id=app_record.id,
tenant_id=app_config.tenant_id, tenant_id=app_config.tenant_id,
app_generate_entity=application_generate_entity, app_generate_entity=application_generate_entity,
inputs=inputs, inputs=dict(inputs),
query=query, query=query or "",
message_id=message.id, message_id=message.id,
) )
except ModerationError as e: except ModerationError as e:
@ -156,9 +154,9 @@ class AgentChatAppRunner(AppRunner):
app_record=app_record, app_record=app_record,
model_config=application_generate_entity.model_conf, model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template, prompt_template_entity=app_config.prompt_template,
inputs=inputs, inputs=dict(inputs),
files=files, files=list(files),
query=query, query=query or "",
memory=memory, memory=memory,
) )
@ -173,16 +171,7 @@ class AgentChatAppRunner(AppRunner):
return return
agent_entity = app_config.agent agent_entity = app_config.agent
if not agent_entity: assert agent_entity is not None
raise ValueError("Agent entity not found")
# load tool variables
tool_conversation_variables = self._load_tool_variables(
conversation_id=conversation.id, user_id=application_generate_entity.user_id, tenant_id=app_config.tenant_id
)
# convert db variables to tool variables
tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
# init model instance # init model instance
model_instance = ModelInstance( model_instance = ModelInstance(
@ -193,17 +182,16 @@ class AgentChatAppRunner(AppRunner):
app_record=app_record, app_record=app_record,
model_config=application_generate_entity.model_conf, model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template, prompt_template_entity=app_config.prompt_template,
inputs=inputs, inputs=dict(inputs),
files=files, files=list(files),
query=query, query=query or "",
memory=memory, memory=memory,
) )
# change function call strategy based on LLM model # change function call strategy based on LLM model
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if not model_schema or not model_schema.features: assert model_schema is not None
raise ValueError("Model schema not found")
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []): if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
@ -243,8 +231,6 @@ class AgentChatAppRunner(AppRunner):
user_id=application_generate_entity.user_id, user_id=application_generate_entity.user_id,
memory=memory, memory=memory,
prompt_messages=prompt_message, prompt_messages=prompt_message,
variables_pool=tool_variables,
db_variables=tool_conversation_variables,
model_instance=model_instance, model_instance=model_instance,
) )
@ -262,50 +248,6 @@ class AgentChatAppRunner(AppRunner):
agent=True, agent=True,
) )
def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables:
"""
load tool variables from database
"""
tool_variables: ToolConversationVariables | None = (
db.session.query(ToolConversationVariables)
.filter(
ToolConversationVariables.conversation_id == conversation_id,
ToolConversationVariables.tenant_id == tenant_id,
)
.first()
)
if tool_variables:
# save tool variables to session, so that we can update it later
db.session.add(tool_variables)
else:
# create new tool variables
tool_variables = ToolConversationVariables(
conversation_id=conversation_id,
user_id=user_id,
tenant_id=tenant_id,
variables_str="[]",
)
db.session.add(tool_variables)
db.session.commit()
return tool_variables
def _convert_db_variables_to_tool_variables(
self, db_variables: ToolConversationVariables
) -> ToolRuntimeVariablePool:
"""
convert db variables to tool variables
"""
return ToolRuntimeVariablePool(
**{
"conversation_id": db_variables.conversation_id,
"user_id": db_variables.user_id,
"tenant_id": db_variables.tenant_id,
"pool": db_variables.variables,
}
)
def _get_usage_of_all_agent_thoughts( def _get_usage_of_all_agent_thoughts(
self, model_config: ModelConfigWithCredentialsEntity, message: Message self, model_config: ModelConfigWithCredentialsEntity, message: Message
) -> LLMUsage: ) -> LLMUsage:

View File

@ -1,9 +1,9 @@
import json
from collections.abc import Generator from collections.abc import Generator
from typing import cast from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import ( from core.app.entities.task_entities import (
AppStreamResponse,
ChatbotAppBlockingResponse, ChatbotAppBlockingResponse,
ChatbotAppStreamResponse, ChatbotAppStreamResponse,
ErrorStreamResponse, ErrorStreamResponse,
@ -51,10 +51,9 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response return response
@classmethod @classmethod
def convert_stream_full_response( # type: ignore[override] def convert_stream_full_response(
cls, cls, stream_response: Generator[AppStreamResponse, None, None]
stream_response: Generator[ChatbotAppStreamResponse, None, None], ) -> Generator[dict | str, None, None]:
) -> Generator[str, None, None]:
""" """
Convert stream full response. Convert stream full response.
:param stream_response: stream response :param stream_response: stream response
@ -80,13 +79,12 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
response_chunk.update(data) response_chunk.update(data)
else: else:
response_chunk.update(sub_stream_response.to_dict()) response_chunk.update(sub_stream_response.to_dict())
yield json.dumps(response_chunk) yield response_chunk
@classmethod @classmethod
def convert_stream_simple_response( # type: ignore[override] def convert_stream_simple_response(
cls, cls, stream_response: Generator[AppStreamResponse, None, None]
stream_response: Generator[ChatbotAppStreamResponse, None, None], ) -> Generator[dict | str, None, None]:
) -> Generator[str, None, None]:
""" """
Convert stream simple response. Convert stream simple response.
:param stream_response: stream response :param stream_response: stream response
@ -118,4 +116,4 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
else: else:
response_chunk.update(sub_stream_response.to_dict()) response_chunk.update(sub_stream_response.to_dict())
yield json.dumps(response_chunk) yield response_chunk

View File

@ -14,21 +14,15 @@ class AppGenerateResponseConverter(ABC):
@classmethod @classmethod
def convert( def convert(
cls, cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
invoke_from: InvokeFrom,
) -> Mapping[str, Any] | Generator[str, None, None]:
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}: if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
if isinstance(response, AppBlockingResponse): if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response) return cls.convert_blocking_full_response(response)
else: else:
def _generate_full_response() -> Generator[str, Any, None]: def _generate_full_response() -> Generator[dict | str, Any, None]:
for chunk in cls.convert_stream_full_response(response): yield from cls.convert_stream_full_response(response)
if chunk == "ping":
yield f"event: {chunk}\n\n"
else:
yield f"data: {chunk}\n\n"
return _generate_full_response() return _generate_full_response()
else: else:
@ -36,12 +30,8 @@ class AppGenerateResponseConverter(ABC):
return cls.convert_blocking_simple_response(response) return cls.convert_blocking_simple_response(response)
else: else:
def _generate_simple_response() -> Generator[str, Any, None]: def _generate_simple_response() -> Generator[dict | str, Any, None]:
for chunk in cls.convert_stream_simple_response(response): yield from cls.convert_stream_simple_response(response)
if chunk == "ping":
yield f"event: {chunk}\n\n"
else:
yield f"data: {chunk}\n\n"
return _generate_simple_response() return _generate_simple_response()
@ -59,14 +49,14 @@ class AppGenerateResponseConverter(ABC):
@abstractmethod @abstractmethod
def convert_stream_full_response( def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None] cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[str, None, None]: ) -> Generator[dict | str, None, None]:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
@abstractmethod @abstractmethod
def convert_stream_simple_response( def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None] cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[str, None, None]: ) -> Generator[dict | str, None, None]:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod

View File

@ -1,5 +1,6 @@
from collections.abc import Mapping, Sequence import json
from typing import TYPE_CHECKING, Any, Optional from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union
from core.app.app_config.entities import VariableEntityType from core.app.app_config.entities import VariableEntityType
from core.file import File, FileUploadConfig from core.file import File, FileUploadConfig
@ -138,3 +139,21 @@ class BaseAppGenerator:
if isinstance(value, str): if isinstance(value, str):
return value.replace("\x00", "") return value.replace("\x00", "")
return value return value
@classmethod
def convert_to_event_stream(cls, generator: Union[Mapping, Generator[Mapping | str, None, None]]):
"""
Convert messages into event stream
"""
if isinstance(generator, dict):
return generator
else:
def gen():
for message in generator:
if isinstance(message, (Mapping, dict)):
yield f"data: {json.dumps(message)}\n\n"
else:
yield f"event: {message}\n\n"
return gen()

View File

@ -2,7 +2,7 @@ import queue
import time import time
from abc import abstractmethod from abc import abstractmethod
from enum import Enum from enum import Enum
from typing import Any from typing import Any, Optional
from sqlalchemy.orm import DeclarativeMeta from sqlalchemy.orm import DeclarativeMeta
@ -115,7 +115,7 @@ class AppQueueManager:
Set task stop flag Set task stop flag
:return: :return:
""" """
result = redis_client.get(cls._generate_task_belong_cache_key(task_id)) result: Optional[Any] = redis_client.get(cls._generate_task_belong_cache_key(task_id))
if result is None: if result is None:
return return

View File

@ -38,7 +38,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: Literal[True], streaming: Literal[True],
) -> Generator[str, None, None]: ... ) -> Generator[Mapping | str, None, None]: ...
@overload @overload
def generate( def generate(
@ -58,7 +58,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool, streaming: bool,
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ... ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
def generate( def generate(
self, self,
@ -67,7 +67,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = True, streaming: bool = True,
): ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
""" """
Generate App response. Generate App response.

View File

@ -1,9 +1,9 @@
import json
from collections.abc import Generator from collections.abc import Generator
from typing import cast from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import ( from core.app.entities.task_entities import (
AppStreamResponse,
ChatbotAppBlockingResponse, ChatbotAppBlockingResponse,
ChatbotAppStreamResponse, ChatbotAppStreamResponse,
ErrorStreamResponse, ErrorStreamResponse,
@ -52,9 +52,8 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod @classmethod
def convert_stream_full_response( def convert_stream_full_response(
cls, cls, stream_response: Generator[AppStreamResponse, None, None]
stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[dict | str, None, None]:
) -> Generator[str, None, None]:
""" """
Convert stream full response. Convert stream full response.
:param stream_response: stream response :param stream_response: stream response
@ -80,13 +79,12 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
response_chunk.update(data) response_chunk.update(data)
else: else:
response_chunk.update(sub_stream_response.to_dict()) response_chunk.update(sub_stream_response.to_dict())
yield json.dumps(response_chunk) yield response_chunk
@classmethod @classmethod
def convert_stream_simple_response( def convert_stream_simple_response(
cls, cls, stream_response: Generator[AppStreamResponse, None, None]
stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[dict | str, None, None]:
) -> Generator[str, None, None]:
""" """
Convert stream simple response. Convert stream simple response.
:param stream_response: stream response :param stream_response: stream response
@ -118,4 +116,4 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
else: else:
response_chunk.update(sub_stream_response.to_dict()) response_chunk.update(sub_stream_response.to_dict())
yield json.dumps(response_chunk) yield response_chunk

View File

@ -37,7 +37,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: Literal[True], streaming: Literal[True],
) -> Generator[str, None, None]: ... ) -> Generator[str | Mapping[str, Any], None, None]: ...
@overload @overload
def generate( def generate(
@ -56,8 +56,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser], user: Union[Account, EndUser],
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool, streaming: bool = False,
) -> Mapping[str, Any] | Generator[str, None, None]: ... ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ...
def generate( def generate(
self, self,
@ -66,7 +66,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = True, streaming: bool = True,
): ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
""" """
Generate App response. Generate App response.
@ -231,7 +231,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser], user: Union[Account, EndUser],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
stream: bool = True, stream: bool = True,
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ) -> Union[Mapping, Generator[Mapping | str, None, None]]:
""" """
Generate App response. Generate App response.

View File

@ -1,9 +1,9 @@
import json
from collections.abc import Generator from collections.abc import Generator
from typing import cast from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import ( from core.app.entities.task_entities import (
AppStreamResponse,
CompletionAppBlockingResponse, CompletionAppBlockingResponse,
CompletionAppStreamResponse, CompletionAppStreamResponse,
ErrorStreamResponse, ErrorStreamResponse,
@ -51,9 +51,8 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod @classmethod
def convert_stream_full_response( def convert_stream_full_response(
cls, cls, stream_response: Generator[AppStreamResponse, None, None]
stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[dict | str, None, None]:
) -> Generator[str, None, None]:
""" """
Convert stream full response. Convert stream full response.
:param stream_response: stream response :param stream_response: stream response
@ -78,13 +77,12 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
response_chunk.update(data) response_chunk.update(data)
else: else:
response_chunk.update(sub_stream_response.to_dict()) response_chunk.update(sub_stream_response.to_dict())
yield json.dumps(response_chunk) yield response_chunk
@classmethod @classmethod
def convert_stream_simple_response( def convert_stream_simple_response(
cls, cls, stream_response: Generator[AppStreamResponse, None, None]
stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[dict | str, None, None]:
) -> Generator[str, None, None]:
""" """
Convert stream simple response. Convert stream simple response.
:param stream_response: stream response :param stream_response: stream response
@ -115,4 +113,4 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
else: else:
response_chunk.update(sub_stream_response.to_dict()) response_chunk.update(sub_stream_response.to_dict())
yield json.dumps(response_chunk) yield response_chunk

View File

@ -36,13 +36,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
*, *,
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
user: Account | EndUser, user: Union[Account, EndUser],
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: Literal[True], streaming: Literal[True],
call_depth: int = 0, call_depth: int,
workflow_thread_pool_id: Optional[str] = None, workflow_thread_pool_id: Optional[str],
) -> Generator[str, None, None]: ... ) -> Generator[Mapping | str, None, None]: ...
@overload @overload
def generate( def generate(
@ -50,12 +50,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
*, *,
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
user: Account | EndUser, user: Union[Account, EndUser],
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: Literal[False], streaming: Literal[False],
call_depth: int = 0, call_depth: int,
workflow_thread_pool_id: Optional[str] = None, workflow_thread_pool_id: Optional[str],
) -> Mapping[str, Any]: ... ) -> Mapping[str, Any]: ...
@overload @overload
@ -64,26 +64,26 @@ class WorkflowAppGenerator(BaseAppGenerator):
*, *,
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
user: Account | EndUser, user: Union[Account, EndUser],
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = True, streaming: bool,
call_depth: int = 0, call_depth: int,
workflow_thread_pool_id: Optional[str] = None, workflow_thread_pool_id: Optional[str],
) -> Mapping[str, Any] | Generator[str, None, None]: ... ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
def generate( def generate(
self, self,
*, *,
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
user: Account | EndUser, user: Union[Account, EndUser],
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = True, streaming: bool = True,
call_depth: int = 0, call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None, workflow_thread_pool_id: Optional[str] = None,
): ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
files: Sequence[Mapping[str, Any]] = args.get("files") or [] files: Sequence[Mapping[str, Any]] = args.get("files") or []
# parse files # parse files
@ -124,7 +124,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
trace_manager=trace_manager, trace_manager=trace_manager,
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
return self._generate( return self._generate(
app_model=app_model, app_model=app_model,
@ -146,7 +149,18 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = True, streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None, workflow_thread_pool_id: Optional[str] = None,
) -> Mapping[str, Any] | Generator[str, None, None]: ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
:param app_model: App
:param workflow: Workflow
:param user: account or end user
:param application_generate_entity: application generate entity
:param invoke_from: invoke from source
:param stream: is stream
:param workflow_thread_pool_id: workflow thread pool id
"""
# init queue manager # init queue manager
queue_manager = WorkflowAppQueueManager( queue_manager = WorkflowAppQueueManager(
task_id=application_generate_entity.task_id, task_id=application_generate_entity.task_id,
@ -185,10 +199,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
node_id: str, node_id: str,
user: Account, user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
streaming: bool = True, streaming: bool = True,
) -> Mapping[str, Any] | Generator[str, None, None]: ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
""" """
Generate App response. Generate App response.
@ -224,6 +238,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_run_id=str(uuid.uuid4()), workflow_run_id=str(uuid.uuid4()),
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
return self._generate( return self._generate(
app_model=app_model, app_model=app_model,

View File

@ -1,9 +1,9 @@
import json
from collections.abc import Generator from collections.abc import Generator
from typing import cast from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import ( from core.app.entities.task_entities import (
AppStreamResponse,
ErrorStreamResponse, ErrorStreamResponse,
NodeFinishStreamResponse, NodeFinishStreamResponse,
NodeStartStreamResponse, NodeStartStreamResponse,
@ -36,9 +36,8 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod @classmethod
def convert_stream_full_response( def convert_stream_full_response(
cls, cls, stream_response: Generator[AppStreamResponse, None, None]
stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[dict | str, None, None]:
) -> Generator[str, None, None]:
""" """
Convert stream full response. Convert stream full response.
:param stream_response: stream response :param stream_response: stream response
@ -62,13 +61,12 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
response_chunk.update(data) response_chunk.update(data)
else: else:
response_chunk.update(sub_stream_response.to_dict()) response_chunk.update(sub_stream_response.to_dict())
yield json.dumps(response_chunk) yield response_chunk
@classmethod @classmethod
def convert_stream_simple_response( def convert_stream_simple_response(
cls, cls, stream_response: Generator[AppStreamResponse, None, None]
stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[dict | str, None, None]:
) -> Generator[str, None, None]:
""" """
Convert stream simple response. Convert stream simple response.
:param stream_response: stream response :param stream_response: stream response
@ -94,4 +92,4 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else: else:
response_chunk.update(sub_stream_response.to_dict()) response_chunk.update(sub_stream_response.to_dict())
yield json.dumps(response_chunk) yield response_chunk

View File

@ -13,6 +13,7 @@ from core.app.entities.app_invoke_entities import (
WorkflowAppGenerateEntity, WorkflowAppGenerateEntity,
) )
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueErrorEvent, QueueErrorEvent,
QueueIterationCompletedEvent, QueueIterationCompletedEvent,
QueueIterationNextEvent, QueueIterationNextEvent,
@ -190,7 +191,9 @@ class WorkflowAppGenerateTaskPipeline:
and features_dict["text_to_speech"].get("enabled") and features_dict["text_to_speech"].get("enabled")
and features_dict["text_to_speech"].get("autoPlay") == "enabled" and features_dict["text_to_speech"].get("autoPlay") == "enabled"
): ):
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice")) tts_publisher = AppGeneratorTTSPublisher(
tenant_id, features_dict["text_to_speech"].get("voice"), features_dict["text_to_speech"].get("language")
)
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True: while True:
@ -527,6 +530,10 @@ class WorkflowAppGenerateTaskPipeline:
yield self._text_chunk_to_stream_response( yield self._text_chunk_to_stream_response(
delta_text, from_variable_selector=event.from_variable_selector delta_text, from_variable_selector=event.from_variable_selector
) )
elif isinstance(event, QueueAgentLogEvent):
yield self._workflow_cycle_manager._handle_agent_log(
task_id=self._application_generate_entity.task_id, event=event
)
else: else:
continue continue

View File

@ -5,6 +5,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner from core.app.apps.base_app_runner import AppRunner
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
AppQueueEvent, AppQueueEvent,
QueueAgentLogEvent,
QueueIterationCompletedEvent, QueueIterationCompletedEvent,
QueueIterationNextEvent, QueueIterationNextEvent,
QueueIterationStartEvent, QueueIterationStartEvent,
@ -27,6 +28,7 @@ from core.app.entities.queue_entities import (
from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import ( from core.workflow.graph_engine.entities.event import (
AgentLogEvent,
GraphEngineEvent, GraphEngineEvent,
GraphRunFailedEvent, GraphRunFailedEvent,
GraphRunPartialSucceededEvent, GraphRunPartialSucceededEvent,
@ -373,6 +375,19 @@ class WorkflowBasedAppRunner(AppRunner):
retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id
) )
) )
elif isinstance(event, AgentLogEvent):
self._publish_event(
QueueAgentLogEvent(
id=event.id,
label=event.label,
node_execution_id=event.node_execution_id,
parent_id=event.parent_id,
error=event.error,
status=event.status,
data=event.data,
metadata=event.metadata,
)
)
elif isinstance(event, ParallelBranchRunStartedEvent): elif isinstance(event, ParallelBranchRunStartedEvent):
self._publish_event( self._publish_event(
QueueParallelBranchRunStartedEvent( QueueParallelBranchRunStartedEvent(

View File

@ -183,7 +183,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
""" """
node_id: str node_id: str
inputs: dict inputs: Mapping
single_iteration_run: Optional[SingleIterationRunEntity] = None single_iteration_run: Optional[SingleIterationRunEntity] = None

View File

@ -41,6 +41,7 @@ class QueueEvent(StrEnum):
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started" PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded" PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed" PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
AGENT_LOG = "agent_log"
ERROR = "error" ERROR = "error"
PING = "ping" PING = "ping"
STOP = "stop" STOP = "stop"
@ -315,6 +316,22 @@ class QueueNodeSucceededEvent(AppQueueEvent):
iteration_duration_map: Optional[dict[str, float]] = None iteration_duration_map: Optional[dict[str, float]] = None
class QueueAgentLogEvent(AppQueueEvent):
"""
QueueAgentLogEvent entity
"""
event: QueueEvent = QueueEvent.AGENT_LOG
id: str
label: str
node_execution_id: str
parent_id: str | None
error: str | None
status: str
data: Mapping[str, Any]
metadata: Optional[Mapping[str, Any]] = None
class QueueNodeRetryEvent(QueueNodeStartedEvent): class QueueNodeRetryEvent(QueueNodeStartedEvent):
"""QueueNodeRetryEvent entity""" """QueueNodeRetryEvent entity"""

View File

@ -60,6 +60,7 @@ class StreamEvent(Enum):
ITERATION_COMPLETED = "iteration_completed" ITERATION_COMPLETED = "iteration_completed"
TEXT_CHUNK = "text_chunk" TEXT_CHUNK = "text_chunk"
TEXT_REPLACE = "text_replace" TEXT_REPLACE = "text_replace"
AGENT_LOG = "agent_log"
class StreamResponse(BaseModel): class StreamResponse(BaseModel):
@ -696,3 +697,26 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
workflow_run_id: str workflow_run_id: str
data: Data data: Data
class AgentLogStreamResponse(StreamResponse):
"""
AgentLogStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
node_execution_id: str
id: str
label: str
parent_id: str | None
error: str | None
status: str
data: Mapping[str, Any]
metadata: Optional[Mapping[str, Any]] = None
event: StreamEvent = StreamEvent.AGENT_LOG
data: Data

View File

@ -24,6 +24,8 @@ class HostingModerationFeature:
if isinstance(prompt_message.content, str): if isinstance(prompt_message.content, str):
text += prompt_message.content + "\n" text += prompt_message.content + "\n"
moderation_result = moderation.check_moderation(model_config, text) moderation_result = moderation.check_moderation(
tenant_id=application_generate_entity.app_config.tenant_id, model_config=model_config, text=text
)
return moderation_result return moderation_result

View File

@ -215,7 +215,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
and text_to_speech_dict.get("autoPlay") == "enabled" and text_to_speech_dict.get("autoPlay") == "enabled"
and text_to_speech_dict.get("enabled") and text_to_speech_dict.get("enabled")
): ):
publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None)) publisher = AppGeneratorTTSPublisher(
tenant_id, text_to_speech_dict.get("voice", None), text_to_speech_dict.get("language", None)
)
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
while True: while True:
audio_response = self._listen_audio_msg(publisher, task_id) audio_response = self._listen_audio_msg(publisher, task_id)

View File

@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueIterationCompletedEvent, QueueIterationCompletedEvent,
QueueIterationNextEvent, QueueIterationNextEvent,
QueueIterationStartEvent, QueueIterationStartEvent,
@ -24,6 +25,7 @@ from core.app.entities.queue_entities import (
QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunSucceededEvent,
) )
from core.app.entities.task_entities import ( from core.app.entities.task_entities import (
AgentLogStreamResponse,
IterationNodeCompletedStreamResponse, IterationNodeCompletedStreamResponse,
IterationNodeNextStreamResponse, IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse, IterationNodeStartStreamResponse,
@ -320,9 +322,8 @@ class WorkflowCycleManage:
inputs = WorkflowEntry.handle_special_values(event.inputs) inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data) process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs) outputs = WorkflowEntry.handle_special_values(event.outputs)
execution_metadata = ( execution_metadata_dict = dict(event.execution_metadata or {})
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None
)
finished_at = datetime.now(UTC).replace(tzinfo=None) finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds() elapsed_time = (finished_at - event.start_at).total_seconds()
@ -843,3 +844,24 @@ class WorkflowCycleManage:
raise ValueError(f"Workflow node execution not found: {node_execution_id}") raise ValueError(f"Workflow node execution not found: {node_execution_id}")
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id] cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
return cached_workflow_node_execution return cached_workflow_node_execution
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
"""
Handle agent log
:param task_id: task id
:param event: agent log event
:return:
"""
return AgentLogStreamResponse(
task_id=task_id,
data=AgentLogStreamResponse.Data(
node_execution_id=event.node_execution_id,
id=event.id,
parent_id=event.parent_id,
label=event.label,
error=event.error,
status=event.status,
data=event.data,
metadata=event.metadata,
),
)

View File

@ -1,4 +1,4 @@
from collections.abc import Mapping, Sequence from collections.abc import Iterable, Mapping
from typing import Any, Optional, TextIO, Union from typing import Any, Optional, TextIO, Union
from pydantic import BaseModel from pydantic import BaseModel
@ -57,7 +57,7 @@ class DifyAgentCallbackHandler(BaseModel):
self, self,
tool_name: str, tool_name: str,
tool_inputs: Mapping[str, Any], tool_inputs: Mapping[str, Any],
tool_outputs: Sequence[ToolInvokeMessage] | str, tool_outputs: Iterable[ToolInvokeMessage] | str,
message_id: Optional[str] = None, message_id: Optional[str] = None,
timer: Optional[Any] = None, timer: Optional[Any] = None,
trace_manager: Optional[TraceQueueManager] = None, trace_manager: Optional[TraceQueueManager] = None,

View File

@ -1,5 +1,26 @@
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from collections.abc import Generator, Iterable, Mapping
from typing import Any, Optional
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler, print_text
from core.ops.ops_trace_manager import TraceQueueManager
from core.tools.entities.tool_entities import ToolInvokeMessage
class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler): class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler):
"""Callback Handler that prints to std out.""" """Callback Handler that prints to std out."""
def on_tool_execution(
self,
tool_name: str,
tool_inputs: Mapping[str, Any],
tool_outputs: Iterable[ToolInvokeMessage],
message_id: Optional[str] = None,
timer: Optional[Any] = None,
trace_manager: Optional[TraceQueueManager] = None,
) -> Generator[ToolInvokeMessage, None, None]:
for tool_output in tool_outputs:
print_text("\n[on_tool_execution]\n", color=self.color)
print_text("Tool: " + tool_name + "\n", color=self.color)
print_text("Outputs: " + tool_output.model_dump_json()[:1000] + "\n", color=self.color)
print_text("\n")
yield tool_output

View File

@ -0,0 +1 @@
DEFAULT_PLUGIN_ID = "langgenius"

View File

@ -0,0 +1,42 @@
from enum import StrEnum
class CommonParameterType(StrEnum):
SECRET_INPUT = "secret-input"
TEXT_INPUT = "text-input"
SELECT = "select"
STRING = "string"
NUMBER = "number"
FILE = "file"
FILES = "files"
SYSTEM_FILES = "system-files"
BOOLEAN = "boolean"
APP_SELECTOR = "app-selector"
MODEL_SELECTOR = "model-selector"
TOOLS_SELECTOR = "array[tools]"
# TOOL_SELECTOR = "tool-selector"
class AppSelectorScope(StrEnum):
ALL = "all"
CHAT = "chat"
WORKFLOW = "workflow"
COMPLETION = "completion"
class ModelSelectorScope(StrEnum):
LLM = "llm"
TEXT_EMBEDDING = "text-embedding"
RERANK = "rerank"
TTS = "tts"
SPEECH2TEXT = "speech2text"
MODERATION = "moderation"
VISION = "vision"
class ToolSelectorScope(StrEnum):
ALL = "all"
CUSTOM = "custom"
BUILTIN = "builtin"
WORKFLOW = "workflow"

View File

@ -2,13 +2,14 @@ import datetime
import json import json
import logging import logging
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterator from collections.abc import Iterator, Sequence
from json import JSONDecodeError from json import JSONDecodeError
from typing import Optional from typing import Optional
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from constants import HIDDEN_VALUE from constants import HIDDEN_VALUE
from core.entities import DEFAULT_PLUGIN_ID
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
from core.entities.provider_entities import ( from core.entities.provider_entities import (
CustomConfiguration, CustomConfiguration,
@ -18,16 +19,15 @@ from core.entities.provider_entities import (
) )
from core.helper import encrypter from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_runtime.entities.model_entities import FetchFrom, ModelType from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.entities.provider_entities import ( from core.model_runtime.entities.provider_entities import (
ConfigurateMethod, ConfigurateMethod,
CredentialFormSchema, CredentialFormSchema,
FormType, FormType,
ProviderEntity, ProviderEntity,
) )
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.ai_model import AIModel from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.__base.model_provider import ModelProvider from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from extensions.ext_database import db from extensions.ext_database import db
from models.provider import ( from models.provider import (
LoadBalancingModelConfig, LoadBalancingModelConfig,
@ -99,9 +99,10 @@ class ProviderConfiguration(BaseModel):
continue continue
restrict_models = quota_configuration.restrict_models restrict_models = quota_configuration.restrict_models
if self.system_configuration.credentials is None:
return None copy_credentials = (
copy_credentials = self.system_configuration.credentials.copy() self.system_configuration.credentials.copy() if self.system_configuration.credentials else {}
)
if restrict_models: if restrict_models:
for restrict_model in restrict_models: for restrict_model in restrict_models:
if ( if (
@ -140,6 +141,9 @@ class ProviderConfiguration(BaseModel):
if current_quota_configuration is None: if current_quota_configuration is None:
return None return None
if not current_quota_configuration:
return SystemConfigurationStatus.UNSUPPORTED
return ( return (
SystemConfigurationStatus.ACTIVE SystemConfigurationStatus.ACTIVE
if current_quota_configuration.is_valid if current_quota_configuration.is_valid
@ -153,7 +157,7 @@ class ProviderConfiguration(BaseModel):
""" """
return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0 return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0
def get_custom_credentials(self, obfuscated: bool = False): def get_custom_credentials(self, obfuscated: bool = False) -> dict | None:
""" """
Get custom credentials. Get custom credentials.
@ -175,7 +179,7 @@ class ProviderConfiguration(BaseModel):
else [], else [],
) )
def custom_credentials_validate(self, credentials: dict) -> tuple[Optional[Provider], dict]: def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]:
""" """
Validate custom credentials. Validate custom credentials.
:param credentials: provider credentials :param credentials: provider credentials
@ -219,6 +223,7 @@ class ProviderConfiguration(BaseModel):
if value == HIDDEN_VALUE and key in original_credentials: if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
model_provider_factory = ModelProviderFactory(self.tenant_id)
credentials = model_provider_factory.provider_credentials_validate( credentials = model_provider_factory.provider_credentials_validate(
provider=self.provider.provider, credentials=credentials provider=self.provider.provider, credentials=credentials
) )
@ -246,13 +251,13 @@ class ProviderConfiguration(BaseModel):
provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
else: else:
provider_record = Provider( provider_record = Provider()
tenant_id=self.tenant_id, provider_record.tenant_id = self.tenant_id
provider_name=self.provider.provider, provider_record.provider_name = self.provider.provider
provider_type=ProviderType.CUSTOM.value, provider_record.provider_type = ProviderType.CUSTOM.value
encrypted_config=json.dumps(credentials), provider_record.encrypted_config = json.dumps(credentials)
is_valid=True, provider_record.is_valid = True
)
db.session.add(provider_record) db.session.add(provider_record)
db.session.commit() db.session.commit()
@ -327,7 +332,7 @@ class ProviderConfiguration(BaseModel):
def custom_model_credentials_validate( def custom_model_credentials_validate(
self, model_type: ModelType, model: str, credentials: dict self, model_type: ModelType, model: str, credentials: dict
) -> tuple[Optional[ProviderModel], dict]: ) -> tuple[ProviderModel | None, dict]:
""" """
Validate custom model credentials. Validate custom model credentials.
@ -370,6 +375,7 @@ class ProviderConfiguration(BaseModel):
if value == HIDDEN_VALUE and key in original_credentials: if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
model_provider_factory = ModelProviderFactory(self.tenant_id)
credentials = model_provider_factory.model_credentials_validate( credentials = model_provider_factory.model_credentials_validate(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
) )
@ -400,14 +406,13 @@ class ProviderConfiguration(BaseModel):
provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
else: else:
provider_model_record = ProviderModel( provider_model_record = ProviderModel()
tenant_id=self.tenant_id, provider_model_record.tenant_id = self.tenant_id
provider_name=self.provider.provider, provider_model_record.provider_name = self.provider.provider
model_name=model, provider_model_record.model_name = model
model_type=model_type.to_origin_model_type(), provider_model_record.model_type = model_type.to_origin_model_type()
encrypted_config=json.dumps(credentials), provider_model_record.encrypted_config = json.dumps(credentials)
is_valid=True, provider_model_record.is_valid = True
)
db.session.add(provider_model_record) db.session.add(provider_model_record)
db.session.commit() db.session.commit()
@ -474,13 +479,12 @@ class ProviderConfiguration(BaseModel):
model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
else: else:
model_setting = ProviderModelSetting( model_setting = ProviderModelSetting()
tenant_id=self.tenant_id, model_setting.tenant_id = self.tenant_id
provider_name=self.provider.provider, model_setting.provider_name = self.provider.provider
model_type=model_type.to_origin_model_type(), model_setting.model_type = model_type.to_origin_model_type()
model_name=model, model_setting.model_name = model
enabled=True, model_setting.enabled = True
)
db.session.add(model_setting) db.session.add(model_setting)
db.session.commit() db.session.commit()
@ -509,13 +513,12 @@ class ProviderConfiguration(BaseModel):
model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
else: else:
model_setting = ProviderModelSetting( model_setting = ProviderModelSetting()
tenant_id=self.tenant_id, model_setting.tenant_id = self.tenant_id
provider_name=self.provider.provider, model_setting.provider_name = self.provider.provider
model_type=model_type.to_origin_model_type(), model_setting.model_type = model_type.to_origin_model_type()
model_name=model, model_setting.model_name = model
enabled=False, model_setting.enabled = False
)
db.session.add(model_setting) db.session.add(model_setting)
db.session.commit() db.session.commit()
@ -576,13 +579,12 @@ class ProviderConfiguration(BaseModel):
model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
else: else:
model_setting = ProviderModelSetting( model_setting = ProviderModelSetting()
tenant_id=self.tenant_id, model_setting.tenant_id = self.tenant_id
provider_name=self.provider.provider, model_setting.provider_name = self.provider.provider
model_type=model_type.to_origin_model_type(), model_setting.model_type = model_type.to_origin_model_type()
model_name=model, model_setting.model_name = model
load_balancing_enabled=True, model_setting.load_balancing_enabled = True
)
db.session.add(model_setting) db.session.add(model_setting)
db.session.commit() db.session.commit()
@ -611,25 +613,17 @@ class ProviderConfiguration(BaseModel):
model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
else: else:
model_setting = ProviderModelSetting( model_setting = ProviderModelSetting()
tenant_id=self.tenant_id, model_setting.tenant_id = self.tenant_id
provider_name=self.provider.provider, model_setting.provider_name = self.provider.provider
model_type=model_type.to_origin_model_type(), model_setting.model_type = model_type.to_origin_model_type()
model_name=model, model_setting.model_name = model
load_balancing_enabled=False, model_setting.load_balancing_enabled = False
)
db.session.add(model_setting) db.session.add(model_setting)
db.session.commit() db.session.commit()
return model_setting return model_setting
def get_provider_instance(self) -> ModelProvider:
"""
Get provider instance.
:return:
"""
return model_provider_factory.get_provider_instance(self.provider.provider)
def get_model_type_instance(self, model_type: ModelType) -> AIModel: def get_model_type_instance(self, model_type: ModelType) -> AIModel:
""" """
Get current model type instance. Get current model type instance.
@ -637,11 +631,19 @@ class ProviderConfiguration(BaseModel):
:param model_type: model type :param model_type: model type
:return: :return:
""" """
# Get provider instance model_provider_factory = ModelProviderFactory(self.tenant_id)
provider_instance = self.get_provider_instance()
# Get model instance of LLM # Get model instance of LLM
return provider_instance.get_model_instance(model_type) return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
def get_model_schema(self, model_type: ModelType, model: str, credentials: dict) -> AIModelEntity | None:
"""
Get model schema
"""
model_provider_factory = ModelProviderFactory(self.tenant_id)
return model_provider_factory.get_model_schema(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
def switch_preferred_provider_type(self, provider_type: ProviderType) -> None: def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
""" """
@ -668,11 +670,10 @@ class ProviderConfiguration(BaseModel):
if preferred_model_provider: if preferred_model_provider:
preferred_model_provider.preferred_provider_type = provider_type.value preferred_model_provider.preferred_provider_type = provider_type.value
else: else:
preferred_model_provider = TenantPreferredModelProvider( preferred_model_provider = TenantPreferredModelProvider()
tenant_id=self.tenant_id, preferred_model_provider.tenant_id = self.tenant_id
provider_name=self.provider.provider, preferred_model_provider.provider_name = self.provider.provider
preferred_provider_type=provider_type.value, preferred_model_provider.preferred_provider_type = provider_type.value
)
db.session.add(preferred_model_provider) db.session.add(preferred_model_provider)
db.session.commit() db.session.commit()
@ -737,13 +738,14 @@ class ProviderConfiguration(BaseModel):
:param only_active: only active models :param only_active: only active models
:return: :return:
""" """
provider_instance = self.get_provider_instance() model_provider_factory = ModelProviderFactory(self.tenant_id)
provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)
model_types = [] model_types: list[ModelType] = []
if model_type: if model_type:
model_types.append(model_type) model_types.append(model_type)
else: else:
model_types = list(provider_instance.get_provider_schema().supported_model_types) model_types = list(provider_schema.supported_model_types)
# Group model settings by model type and model # Group model settings by model type and model
model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict) model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
@ -752,11 +754,11 @@ class ProviderConfiguration(BaseModel):
if self.using_provider_type == ProviderType.SYSTEM: if self.using_provider_type == ProviderType.SYSTEM:
provider_models = self._get_system_provider_models( provider_models = self._get_system_provider_models(
model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
) )
else: else:
provider_models = self._get_custom_provider_models( provider_models = self._get_custom_provider_models(
model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
) )
if only_active: if only_active:
@ -767,23 +769,26 @@ class ProviderConfiguration(BaseModel):
def _get_system_provider_models( def _get_system_provider_models(
self, self,
model_types: list[ModelType], model_types: Sequence[ModelType],
provider_instance: ModelProvider, provider_schema: ProviderEntity,
model_setting_map: dict[ModelType, dict[str, ModelSettings]], model_setting_map: dict[ModelType, dict[str, ModelSettings]],
) -> list[ModelWithProviderEntity]: ) -> list[ModelWithProviderEntity]:
""" """
Get system provider models. Get system provider models.
:param model_types: model types :param model_types: model types
:param provider_instance: provider instance :param provider_schema: provider schema
:param model_setting_map: model setting map :param model_setting_map: model setting map
:return: :return:
""" """
provider_models = [] provider_models = []
for model_type in model_types: for model_type in model_types:
for m in provider_instance.models(model_type): for m in provider_schema.models:
if m.model_type != model_type:
continue
status = ModelStatus.ACTIVE status = ModelStatus.ACTIVE
if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]: if m.model in model_setting_map:
model_setting = model_setting_map[m.model_type][m.model] model_setting = model_setting_map[m.model_type][m.model]
if model_setting.enabled is False: if model_setting.enabled is False:
status = ModelStatus.DISABLED status = ModelStatus.DISABLED
@ -804,7 +809,7 @@ class ProviderConfiguration(BaseModel):
if self.provider.provider not in original_provider_configurate_methods: if self.provider.provider not in original_provider_configurate_methods:
original_provider_configurate_methods[self.provider.provider] = [] original_provider_configurate_methods[self.provider.provider] = []
for configurate_method in provider_instance.get_provider_schema().configurate_methods: for configurate_method in provider_schema.configurate_methods:
original_provider_configurate_methods[self.provider.provider].append(configurate_method) original_provider_configurate_methods[self.provider.provider].append(configurate_method)
should_use_custom_model = False should_use_custom_model = False
@ -825,18 +830,22 @@ class ProviderConfiguration(BaseModel):
]: ]:
# only customizable model # only customizable model
for restrict_model in restrict_models: for restrict_model in restrict_models:
if self.system_configuration.credentials is not None: copy_credentials = (
copy_credentials = self.system_configuration.credentials.copy() self.system_configuration.credentials.copy()
if restrict_model.base_model_name: if self.system_configuration.credentials
copy_credentials["base_model_name"] = restrict_model.base_model_name else {}
)
if restrict_model.base_model_name:
copy_credentials["base_model_name"] = restrict_model.base_model_name
try: try:
custom_model_schema = provider_instance.get_model_instance( custom_model_schema = self.get_model_schema(
restrict_model.model_type model_type=restrict_model.model_type,
).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials) model=restrict_model.model,
except Exception as ex: credentials=copy_credentials,
logger.warning(f"get custom model schema failed, {ex}") )
continue except Exception as ex:
logger.warning(f"get custom model schema failed, {ex}")
if not custom_model_schema: if not custom_model_schema:
continue continue
@ -881,15 +890,15 @@ class ProviderConfiguration(BaseModel):
def _get_custom_provider_models( def _get_custom_provider_models(
self, self,
model_types: list[ModelType], model_types: Sequence[ModelType],
provider_instance: ModelProvider, provider_schema: ProviderEntity,
model_setting_map: dict[ModelType, dict[str, ModelSettings]], model_setting_map: dict[ModelType, dict[str, ModelSettings]],
) -> list[ModelWithProviderEntity]: ) -> list[ModelWithProviderEntity]:
""" """
Get custom provider models. Get custom provider models.
:param model_types: model types :param model_types: model types
:param provider_instance: provider instance :param provider_schema: provider schema
:param model_setting_map: model setting map :param model_setting_map: model setting map
:return: :return:
""" """
@ -903,8 +912,10 @@ class ProviderConfiguration(BaseModel):
if model_type not in self.provider.supported_model_types: if model_type not in self.provider.supported_model_types:
continue continue
models = provider_instance.models(model_type) for m in provider_schema.models:
for m in models: if m.model_type != model_type:
continue
status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
load_balancing_enabled = False load_balancing_enabled = False
if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]: if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
@ -936,10 +947,10 @@ class ProviderConfiguration(BaseModel):
continue continue
try: try:
custom_model_schema = provider_instance.get_model_instance( custom_model_schema = self.get_model_schema(
model_configuration.model_type model_type=model_configuration.model_type,
).get_customizable_model_schema_from_credentials( model=model_configuration.model,
model_configuration.model, model_configuration.credentials credentials=model_configuration.credentials,
) )
except Exception as ex: except Exception as ex:
logger.warning(f"get custom model schema failed, {ex}") logger.warning(f"get custom model schema failed, {ex}")
@ -967,7 +978,7 @@ class ProviderConfiguration(BaseModel):
label=custom_model_schema.label, label=custom_model_schema.label,
model_type=custom_model_schema.model_type, model_type=custom_model_schema.model_type,
features=custom_model_schema.features, features=custom_model_schema.features,
fetch_from=custom_model_schema.fetch_from, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties=custom_model_schema.model_properties, model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated, deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider), provider=SimpleModelProviderEntity(self.provider),
@ -1040,6 +1051,9 @@ class ProviderConfigurations(BaseModel):
return list(self.values()) return list(self.values())
def __getitem__(self, key): def __getitem__(self, key):
if "/" not in key:
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
return self.configurations[key] return self.configurations[key]
def __setitem__(self, key, value): def __setitem__(self, key, value):
@ -1051,8 +1065,11 @@ class ProviderConfigurations(BaseModel):
def values(self) -> Iterator[ProviderConfiguration]: def values(self) -> Iterator[ProviderConfiguration]:
return iter(self.configurations.values()) return iter(self.configurations.values())
def get(self, key, default=None): def get(self, key, default=None) -> ProviderConfiguration | None:
return self.configurations.get(key, default) if "/" not in key:
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
return self.configurations.get(key, default) # type: ignore
class ProviderModelBundle(BaseModel): class ProviderModelBundle(BaseModel):
@ -1061,7 +1078,6 @@ class ProviderModelBundle(BaseModel):
""" """
configuration: ProviderConfiguration configuration: ProviderConfiguration
provider_instance: ModelProvider
model_type_instance: AIModel model_type_instance: AIModel
# pydantic configs # pydantic configs

View File

@ -1,10 +1,34 @@
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional, Union
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict, Field
from core.entities.parameter_entities import (
AppSelectorScope,
CommonParameterType,
ModelSelectorScope,
ToolSelectorScope,
)
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from models.provider import ProviderQuotaType from core.tools.entities.common_entities import I18nObject
class ProviderQuotaType(Enum):
PAID = "paid"
"""hosted paid quota"""
FREE = "free"
"""third-party free quota"""
TRIAL = "trial"
"""hosted trial quota"""
@staticmethod
def value_of(value):
for member in ProviderQuotaType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class QuotaUnit(Enum): class QuotaUnit(Enum):
@ -108,3 +132,55 @@ class ModelSettings(BaseModel):
# pydantic configs # pydantic configs
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
class BasicProviderConfig(BaseModel):
"""
Base model class for common provider settings like credentials
"""
class Type(Enum):
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
TEXT_INPUT = CommonParameterType.TEXT_INPUT.value
SELECT = CommonParameterType.SELECT.value
BOOLEAN = CommonParameterType.BOOLEAN.value
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
@classmethod
def value_of(cls, value: str) -> "ProviderConfig.Type":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid mode value {value}")
type: Type = Field(..., description="The type of the credentials")
name: str = Field(..., description="The name of the credentials")
class ProviderConfig(BasicProviderConfig):
"""
Model class for common provider settings like credentials
"""
class Option(BaseModel):
value: str = Field(..., description="The value of the option")
label: I18nObject = Field(..., description="The label of the option")
scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None
required: bool = False
default: Optional[Union[int, str]] = None
options: Optional[list[Option]] = None
label: Optional[I18nObject] = None
help: Optional[I18nObject] = None
url: Optional[str] = None
placeholder: Optional[I18nObject] = None
def to_basic_provider_config(self) -> BasicProviderConfig:
return BasicProviderConfig(type=self.type, name=self.name)

View File

@ -20,6 +20,41 @@ def get_signed_file_url(upload_file_id: str) -> str:
return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, user_id: str) -> str:
url = f"{dify_config.FILES_URL}/files/upload/for-plugin"
if user_id is None:
user_id = "DEFAULT-USER"
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
key = dify_config.SECRET_KEY.encode()
msg = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}&user_id={user_id}"
def verify_plugin_file_signature(
*, filename: str, mimetype: str, tenant_id: str, user_id: str | None, timestamp: str, nonce: str, sign: str
) -> bool:
if user_id is None:
user_id = "DEFAULT-USER"
data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode()
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature
if sign != recalculated_encoded_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() secret_key = dify_config.SECRET_KEY.encode()

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Optional from typing import Any, Optional
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
@ -124,6 +124,17 @@ class File(BaseModel):
tool_file_id=self.related_id, extension=self.extension tool_file_id=self.related_id, extension=self.extension
) )
def to_plugin_parameter(self) -> dict[str, Any]:
return {
"dify_model_identity": FILE_MODEL_IDENTITY,
"mime_type": self.mime_type,
"filename": self.filename,
"extension": self.extension,
"size": self.size,
"type": self.type,
"url": self.generate_url(),
}
@model_validator(mode="after") @model_validator(mode="after")
def validate_after(self): def validate_after(self):
match self.transfer_method: match self.transfer_method:

View File

@ -0,0 +1,69 @@
import base64
import logging
import time
from typing import Optional
from configs import dify_config
from core.helper.url_signer import UrlSigner
from extensions.ext_storage import storage
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
class UploadFileParser:
@classmethod
def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
if not upload_file:
return None
if upload_file.extension not in IMAGE_EXTENSIONS:
return None
if dify_config.MULTIMODAL_SEND_FORMAT == "url" or force_url:
return cls.get_signed_temp_image_url(upload_file.id)
else:
# get image file base64
try:
data = storage.load(upload_file.key)
except FileNotFoundError:
logging.exception(f"File not found: {upload_file.key}")
return None
encoded_string = base64.b64encode(data).decode("utf-8")
return f"data:{upload_file.mime_type};base64,{encoded_string}"
@classmethod
def get_signed_temp_image_url(cls, upload_file_id) -> str:
"""
get signed url from upload file
:param upload_file: UploadFile object
:return:
"""
base_url = dify_config.FILES_URL
image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"
return UrlSigner.get_signed_url(url=image_preview_url, sign_key=upload_file_id, prefix="image-preview")
@classmethod
def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
"""
verify signature
:param upload_file_id: file id
:param timestamp: timestamp
:param nonce: nonce
:param sign: signature
:return:
"""
result = UrlSigner.verify(
sign_key=upload_file_id, timestamp=timestamp, nonce=nonce, sign=sign, prefix="image-preview"
)
# verify signature
if not result:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT

View File

@ -0,0 +1,17 @@
from core.helper import ssrf_proxy
def download_with_size_limit(url, max_download_size: int, **kwargs):
response = ssrf_proxy.get(url, follow_redirects=True, **kwargs)
if response.status_code == 404:
raise ValueError("file not found")
total_size = 0
chunks = []
for chunk in response.iter_bytes():
total_size += len(chunk)
if total_size > max_download_size:
raise ValueError("Max file size reached")
chunks.append(chunk)
content = b"".join(chunks)
return content

View File

@ -0,0 +1,35 @@
from collections.abc import Sequence
import requests
from yarl import URL
from configs import dify_config
from core.helper.download import download_with_size_limit
from core.plugin.entities.marketplace import MarketplacePluginDeclaration
def get_plugin_pkg_url(plugin_unique_identifier: str):
return (URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/plugins/download").with_query(
unique_identifier=plugin_unique_identifier
)
def download_plugin_pkg(plugin_unique_identifier: str):
url = str(get_plugin_pkg_url(plugin_unique_identifier))
return download_with_size_limit(url, dify_config.PLUGIN_MAX_PACKAGE_SIZE)
def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplacePluginDeclaration]:
if len(plugin_ids) == 0:
return []
url = str(URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/plugins/batch")
response = requests.post(url, json={"plugin_ids": plugin_ids})
response.raise_for_status()
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
def record_install_plugin_event(plugin_unique_identifier: str):
url = str(URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/stats/plugins/install_count")
response = requests.post(url, json={"unique_identifier": plugin_unique_identifier})
response.raise_for_status()

View File

@ -1,28 +1,35 @@
import logging import logging
import random import random
from typing import cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities import DEFAULT_PLUGIN_ID
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from extensions.ext_hosting_provider import hosting_configuration from extensions.ext_hosting_provider import hosting_configuration
from models.provider import ProviderType from models.provider import ProviderType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) -> bool: def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEntity, text: str) -> bool:
moderation_config = hosting_configuration.moderation_config moderation_config = hosting_configuration.moderation_config
openai_provider_name = f"{DEFAULT_PLUGIN_ID}/openai/openai"
if ( if (
moderation_config moderation_config
and moderation_config.enabled is True and moderation_config.enabled is True
and "openai" in hosting_configuration.provider_map and openai_provider_name in hosting_configuration.provider_map
and hosting_configuration.provider_map["openai"].enabled is True and hosting_configuration.provider_map[openai_provider_name].enabled is True
): ):
using_provider_type = model_config.provider_model_bundle.configuration.using_provider_type using_provider_type = model_config.provider_model_bundle.configuration.using_provider_type
provider_name = model_config.provider provider_name = model_config.provider
if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers: if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers:
hosting_openai_config = hosting_configuration.provider_map["openai"] hosting_openai_config = hosting_configuration.provider_map[openai_provider_name]
assert hosting_openai_config is not None
if hosting_openai_config.credentials is None:
return False
# 2000 text per chunk # 2000 text per chunk
length = 2000 length = 2000
@ -34,15 +41,20 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str)
text_chunk = random.choice(text_chunks) text_chunk = random.choice(text_chunks)
try: try:
model_type_instance = OpenAIModerationModel() model_provider_factory = ModelProviderFactory(tenant_id)
# FIXME, for type hint using assert or raise ValueError is better here?
# Get model instance of LLM
model_type_instance = model_provider_factory.get_model_type_instance(
provider=openai_provider_name, model_type=ModelType.MODERATION
)
model_type_instance = cast(ModerationModel, model_type_instance)
moderation_result = model_type_instance.invoke( moderation_result = model_type_instance.invoke(
model="text-moderation-stable", credentials=hosting_openai_config.credentials or {}, text=text_chunk model="omni-moderation-latest", credentials=hosting_openai_config.credentials, text=text_chunk
) )
if moderation_result is True: if moderation_result is True:
return True return True
except Exception as ex: except Exception:
logger.exception(f"Fails to check moderation, provider_name: {provider_name}") logger.exception(f"Fails to check moderation, provider_name: {provider_name}")
raise InvokeBadRequestError("Rate limit exceeded, please try again later.") raise InvokeBadRequestError("Rate limit exceeded, please try again later.")

View File

@ -45,7 +45,6 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
) )
retries = 0 retries = 0
stream = kwargs.pop("stream", False)
while retries <= max_retries: while retries <= max_retries:
try: try:
if dify_config.SSRF_PROXY_ALL_URL: if dify_config.SSRF_PROXY_ALL_URL:
@ -61,17 +60,20 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if response.status_code not in STATUS_FORCELIST: if response.status_code not in STATUS_FORCELIST:
return response return response
else: else:
logging.warning(f"Received status code {response.status_code} for URL {url} which is in the force list") logging.warning(
f"Received status code {response.status_code} for URL {url} which is in the force list")
except httpx.RequestError as e: except httpx.RequestError as e:
logging.warning(f"Request to URL {url} failed on attempt {retries + 1}: {e}") logging.warning(f"Request to URL {url} failed on attempt {
retries + 1}: {e}")
if max_retries == 0: if max_retries == 0:
raise raise
retries += 1 retries += 1
if retries <= max_retries: if retries <= max_retries:
time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1))) time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))
raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}") raise MaxRetriesExceededError(
f"Reached maximum retries ({max_retries}) for URL {url}")
def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):

View File

@ -8,6 +8,7 @@ from extensions.ext_redis import redis_client
class ToolProviderCredentialsCacheType(Enum): class ToolProviderCredentialsCacheType(Enum):
PROVIDER = "tool_provider" PROVIDER = "tool_provider"
ENDPOINT = "endpoint"
class ToolProviderCredentialsCache: class ToolProviderCredentialsCache:

View File

@ -0,0 +1,52 @@
import base64
import hashlib
import hmac
import os
import time
from pydantic import BaseModel, Field
from configs import dify_config
class SignedUrlParams(BaseModel):
sign_key: str = Field(..., description="The sign key")
timestamp: str = Field(..., description="Timestamp")
nonce: str = Field(..., description="Nonce")
sign: str = Field(..., description="Signature")
class UrlSigner:
@classmethod
def get_signed_url(cls, url: str, sign_key: str, prefix: str) -> str:
signed_url_params = cls.get_signed_url_params(sign_key, prefix)
return (
f"{url}?timestamp={signed_url_params.timestamp}"
f"&nonce={signed_url_params.nonce}&sign={signed_url_params.sign}"
)
@classmethod
def get_signed_url_params(cls, sign_key: str, prefix: str) -> SignedUrlParams:
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
sign = cls._sign(sign_key, timestamp, nonce, prefix)
return SignedUrlParams(sign_key=sign_key, timestamp=timestamp, nonce=nonce, sign=sign)
@classmethod
def verify(cls, sign_key: str, timestamp: str, nonce: str, sign: str, prefix: str) -> bool:
recalculated_sign = cls._sign(sign_key, timestamp, nonce, prefix)
return sign == recalculated_sign
@classmethod
def _sign(cls, sign_key: str, timestamp: str, nonce: str, prefix: str) -> str:
if not dify_config.SECRET_KEY:
raise Exception("SECRET_KEY is not set")
data_to_sign = f"{prefix}|{sign_key}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode()
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return encoded_sign

View File

@ -4,9 +4,9 @@ from flask import Flask
from pydantic import BaseModel from pydantic import BaseModel
from configs import dify_config from configs import dify_config
from core.entities.provider_entities import QuotaUnit, RestrictModel from core.entities import DEFAULT_PLUGIN_ID
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, RestrictModel
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from models.provider import ProviderQuotaType
class HostingQuota(BaseModel): class HostingQuota(BaseModel):
@ -48,12 +48,12 @@ class HostingConfiguration:
if dify_config.EDITION != "CLOUD": if dify_config.EDITION != "CLOUD":
return return
self.provider_map["azure_openai"] = self.init_azure_openai() self.provider_map[f"{DEFAULT_PLUGIN_ID}/azure_openai/azure_openai"] = self.init_azure_openai()
self.provider_map["openai"] = self.init_openai() self.provider_map[f"{DEFAULT_PLUGIN_ID}/openai/openai"] = self.init_openai()
self.provider_map["anthropic"] = self.init_anthropic() self.provider_map[f"{DEFAULT_PLUGIN_ID}/anthropic/anthropic"] = self.init_anthropic()
self.provider_map["minimax"] = self.init_minimax() self.provider_map[f"{DEFAULT_PLUGIN_ID}/minimax/minimax"] = self.init_minimax()
self.provider_map["spark"] = self.init_spark() self.provider_map[f"{DEFAULT_PLUGIN_ID}/spark/spark"] = self.init_spark()
self.provider_map["zhipuai"] = self.init_zhipuai() self.provider_map[f"{DEFAULT_PLUGIN_ID}/zhipuai/zhipuai"] = self.init_zhipuai()
self.moderation_config = self.init_moderation_config() self.moderation_config = self.init_moderation_config()
@ -240,7 +240,14 @@ class HostingConfiguration:
@staticmethod @staticmethod
def init_moderation_config() -> HostedModerationConfig: def init_moderation_config() -> HostedModerationConfig:
if dify_config.HOSTED_MODERATION_ENABLED and dify_config.HOSTED_MODERATION_PROVIDERS: if dify_config.HOSTED_MODERATION_ENABLED and dify_config.HOSTED_MODERATION_PROVIDERS:
return HostedModerationConfig(enabled=True, providers=dify_config.HOSTED_MODERATION_PROVIDERS.split(",")) providers = dify_config.HOSTED_MODERATION_PROVIDERS.split(",")
hosted_providers = []
for provider in providers:
if "/" not in provider:
provider = f"{DEFAULT_PLUGIN_ID}/{provider}/{provider}"
hosted_providers.append(provider)
return HostedModerationConfig(enabled=True, providers=hosted_providers)
return HostedModerationConfig(enabled=False) return HostedModerationConfig(enabled=False)

View File

@ -30,7 +30,7 @@ from core.rag.splitter.fixed_text_splitter import (
FixedRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter,
) )
from core.rag.splitter.text_splitter import TextSplitter from core.rag.splitter.text_splitter import TextSplitter
from core.tools.utils.web_reader_tool import get_image_upload_file_ids from core.tools.utils.rag_web_reader import get_image_upload_file_ids
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from extensions.ext_storage import storage from extensions.ext_storage import storage
@ -608,10 +608,8 @@ class IndexingRunner:
tokens = 0 tokens = 0
if embedding_model_instance: if embedding_model_instance:
tokens += sum( page_content_list = [document.page_content for document in chunk_documents]
embedding_model_instance.get_text_embedding_num_tokens([document.page_content]) tokens += sum(embedding_model_instance.get_text_embedding_num_tokens(page_content_list))
for document in chunk_documents
)
# load index # load index
index_processor.load(dataset, chunk_documents, with_keywords=False) index_processor.load(dataset, chunk_documents, with_keywords=False)

View File

@ -48,7 +48,7 @@ class LLMGenerator:
response = cast( response = cast(
LLMResult, LLMResult,
model_instance.invoke_llm( model_instance.invoke_llm(
prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False prompt_messages=list(prompts), model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
), ),
) )
answer = cast(str, response.message.content) answer = cast(str, response.message.content)
@ -101,7 +101,7 @@ class LLMGenerator:
response = cast( response = cast(
LLMResult, LLMResult,
model_instance.invoke_llm( model_instance.invoke_llm(
prompt_messages=prompt_messages, prompt_messages=list(prompt_messages),
model_parameters={"max_tokens": 256, "temperature": 0}, model_parameters={"max_tokens": 256, "temperature": 0},
stream=False, stream=False,
), ),
@ -110,7 +110,7 @@ class LLMGenerator:
questions = output_parser.parse(cast(str, response.message.content)) questions = output_parser.parse(cast(str, response.message.content))
except InvokeError: except InvokeError:
questions = [] questions = []
except Exception as e: except Exception:
logging.exception("Failed to generate suggested questions after answer") logging.exception("Failed to generate suggested questions after answer")
questions = [] questions = []
@ -150,7 +150,7 @@ class LLMGenerator:
response = cast( response = cast(
LLMResult, LLMResult,
model_instance.invoke_llm( model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
), ),
) )
@ -200,7 +200,7 @@ class LLMGenerator:
prompt_content = cast( prompt_content = cast(
LLMResult, LLMResult,
model_instance.invoke_llm( model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
), ),
) )
except InvokeError as e: except InvokeError as e:
@ -236,7 +236,7 @@ class LLMGenerator:
parameter_content = cast( parameter_content = cast(
LLMResult, LLMResult,
model_instance.invoke_llm( model_instance.invoke_llm(
prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
), ),
) )
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content)) rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
@ -248,7 +248,7 @@ class LLMGenerator:
statement_content = cast( statement_content = cast(
LLMResult, LLMResult,
model_instance.invoke_llm( model_instance.invoke_llm(
prompt_messages=statement_messages, model_parameters=model_parameters, stream=False prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
), ),
) )
rule_config["opening_statement"] = cast(str, statement_content.message.content) rule_config["opening_statement"] = cast(str, statement_content.message.content)
@ -301,7 +301,7 @@ class LLMGenerator:
response = cast( response = cast(
LLMResult, LLMResult,
model_instance.invoke_llm( model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
), ),
) )

View File

@ -1,6 +1,6 @@
import logging import logging
from collections.abc import Callable, Generator, Iterable, Sequence from collections.abc import Callable, Generator, Iterable, Sequence
from typing import IO, Any, Optional, Union, cast from typing import IO, Any, Literal, Optional, Union, cast, overload
from configs import dify_config from configs import dify_config
from core.entities.embedding_type import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
@ -98,6 +98,42 @@ class ModelInstance:
return None return None
@overload
def invoke_llm(
self,
prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Optional[list[str]] = None,
stream: Literal[True] = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
) -> Generator: ...
@overload
def invoke_llm(
self,
prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Optional[list[str]] = None,
stream: Literal[False] = False,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
) -> LLMResult: ...
@overload
def invoke_llm(
self,
prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
) -> Union[LLMResult, Generator]: ...
def invoke_llm( def invoke_llm(
self, self,
prompt_messages: Sequence[PromptMessage], prompt_messages: Sequence[PromptMessage],
@ -192,7 +228,7 @@ class ModelInstance:
), ),
) )
def get_text_embedding_num_tokens(self, texts: list[str]) -> int: def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
""" """
Get number of tokens for text embedding Get number of tokens for text embedding
@ -204,7 +240,7 @@ class ModelInstance:
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast( return cast(
int, list[int],
self._round_robin_invoke( self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens, function=self.model_type_instance.get_num_tokens,
model=self.model, model=self.model,
@ -397,7 +433,7 @@ class ModelManager:
return ModelInstance(provider_model_bundle, model) return ModelInstance(provider_model_bundle, model)
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]: def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
""" """
Return first provider and the first model in the provider Return first provider and the first model in the provider
:param tenant_id: tenant id :param tenant_id: tenant id

View File

@ -18,7 +18,6 @@ class ModelType(Enum):
SPEECH2TEXT = "speech2text" SPEECH2TEXT = "speech2text"
MODERATION = "moderation" MODERATION = "moderation"
TTS = "tts" TTS = "tts"
TEXT2IMG = "text2img"
@classmethod @classmethod
def value_of(cls, origin_model_type: str) -> "ModelType": def value_of(cls, origin_model_type: str) -> "ModelType":
@ -37,8 +36,6 @@ class ModelType(Enum):
return cls.SPEECH2TEXT return cls.SPEECH2TEXT
elif origin_model_type in {"tts", cls.TTS.value}: elif origin_model_type in {"tts", cls.TTS.value}:
return cls.TTS return cls.TTS
elif origin_model_type in {"text2img", cls.TEXT2IMG.value}:
return cls.TEXT2IMG
elif origin_model_type == cls.MODERATION.value: elif origin_model_type == cls.MODERATION.value:
return cls.MODERATION return cls.MODERATION
else: else:
@ -62,8 +59,6 @@ class ModelType(Enum):
return "tts" return "tts"
elif self == self.MODERATION: elif self == self.MODERATION:
return "moderation" return "moderation"
elif self == self.TEXT2IMG:
return "text2img"
else: else:
raise ValueError(f"invalid model type {self}") raise ValueError(f"invalid model type {self}")

Some files were not shown because too many files have changed in this diff Show More