From 296887754f9dbec46cce9445a0fb64b1f6303f7d Mon Sep 17 00:00:00 2001
From: Patryk Garstecki
Date: Fri, 24 May 2024 06:01:40 +0200
Subject: [PATCH] Support for Vertex AI (#4586)

---
 .../model_providers/_position.yaml            |   1 +
 .../model_providers/vertex_ai/__init__.py     |   0
 .../vertex_ai/_assets/icon_l_en.png           | Bin 0 -> 18078 bytes
 .../vertex_ai/_assets/icon_s_en.svg           |   1 +
 .../model_providers/vertex_ai/_common.py      |  15 +
 .../model_providers/vertex_ai/llm/__init__.py |   0
 .../vertex_ai/llm/gemini-1.0-pro-vision.yaml  |  38 ++
 .../vertex_ai/llm/gemini-1.0-pro.yaml         |  38 ++
 .../vertex_ai/llm/gemini-1.5-flash.yaml       |  38 ++
 .../vertex_ai/llm/gemini-1.5-pro.yaml         |  39 ++
 .../model_providers/vertex_ai/llm/llm.py      | 438 ++++++++++++++++++
 .../vertex_ai/text_embedding/__init__.py      |   0
 .../text_embedding/text-embedding-004.yaml    |   8 +
 .../text-multilingual-embedding-002.yaml      |   8 +
 .../text_embedding/text_embedding.py          | 193 ++++++++
 .../model_providers/vertex_ai/vertex_ai.py    |  31 ++
 .../model_providers/vertex_ai/vertex_ai.yaml  |  43 ++
 api/requirements.txt                          |   1 +
 18 files changed, 892 insertions(+)
 create mode 100644 api/core/model_runtime/model_providers/vertex_ai/__init__.py
 create mode 100644 api/core/model_runtime/model_providers/vertex_ai/_assets/icon_l_en.png
 create mode 100644 api/core/model_runtime/model_providers/vertex_ai/_assets/icon_s_en.svg
 create mode 100644 api/core/model_runtime/model_providers/vertex_ai/_common.py
 create mode 100644 api/core/model_runtime/model_providers/vertex_ai/llm/__init__.py
 create mode 100644 api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro-vision.yaml
 create mode 100644 api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro.yaml
 create mode 100644 api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-flash.yaml
 create mode 100644 api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-pro.yaml
 create mode 100644 api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
 create mode 100644 api/core/model_runtime/model_providers/vertex_ai/text_embedding/__init__.py
 create mode 100644 api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-embedding-004.yaml
 create mode 100644 api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-multilingual-embedding-002.yaml
 create mode 100644 api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py
 create mode 100644 api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py
 create mode 100644 api/core/model_runtime/model_providers/vertex_ai/vertex_ai.yaml

diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/core/model_runtime/model_providers/_position.yaml
index 1c27b2b4aa..a868cb8c78 100644
--- a/api/core/model_runtime/model_providers/_position.yaml
+++ b/api/core/model_runtime/model_providers/_position.yaml
@@ -2,6 +2,7 @@
 - anthropic
 - azure_openai
 - google
+- vertex_ai
 - nvidia
 - cohere
 - bedrock
diff --git a/api/core/model_runtime/model_providers/vertex_ai/__init__.py b/api/core/model_runtime/model_providers/vertex_ai/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_l_en.png
new file mode 100644
index 0000000000000000000000000000000000000000..9f8f05231a6733a138502ba7a254e2125ae8d02b
GIT binary patch
literal 18078
[base85-encoded PNG data omitted]
diff --git a/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_s_en.svg
new file mode 100644
index 0000000000..efc3589c07
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/_assets/icon_s_en.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/vertex_ai/_common.py b/api/core/model_runtime/model_providers/vertex_ai/_common.py
new file mode 100644
index 0000000000..8f7c859e38
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/_common.py
@@ -0,0 +1,15 @@
+from core.model_runtime.errors.invoke import InvokeError
+
+
+class _CommonVertexAi:
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+ + :return: Invoke error mapping + """ + pass diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/__init__.py b/api/core/model_runtime/model_providers/vertex_ai/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro-vision.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro-vision.yaml new file mode 100644 index 0000000000..da3bc8a64a --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro-vision.yaml @@ -0,0 +1,38 @@ +model: gemini-1.0-pro-vision-001 +label: + en_US: Gemini 1.0 Pro Vision +model_type: llm +features: + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 16384 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + en_US: Top k + type: int + help: + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_output_tokens + use_template: max_tokens + required: true + default: 2048 + min: 1 + max: 2048 +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro.yaml new file mode 100644 index 0000000000..029fab718c --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro.yaml @@ -0,0 +1,38 @@ +model: gemini-1.0-pro-002 +label: + en_US: Gemini 1.0 Pro +model_type: llm +features: + - agent-thought + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 32760 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + en_US: Top k + type: int + help: + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_output_tokens + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-flash.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-flash.yaml new file mode 100644 index 0000000000..72b8410aa1 --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-flash.yaml @@ -0,0 +1,38 @@ +model: gemini-1.5-flash-preview-0514 +label: + en_US: Gemini 1.5 Flash +model_type: llm +features: + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 1048576 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + en_US: Top k + type: int + help: + en_US: Only sample from the top K options for each subsequent token. 
+    required: false
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_output_tokens
+    use_template: max_tokens
+    required: true
+    default: 8192
+    min: 1
+    max: 8192
+pricing:
+  input: '0.00'
+  output: '0.00'
+  unit: '0.000001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-pro.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-pro.yaml
new file mode 100644
index 0000000000..141f61aad6
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-pro.yaml
@@ -0,0 +1,39 @@
+model: gemini-1.5-pro-preview-0514
+label:
+  en_US: Gemini 1.5 Pro
+model_type: llm
+features:
+  - agent-thought
+  - vision
+  - tool-call
+  - stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 1048576
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      en_US: Top k
+    type: int
+    help:
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_output_tokens
+    use_template: max_tokens
+    required: true
+    default: 8192
+    min: 1
+    max: 8192
+pricing:
+  input: '0.00'
+  output: '0.00'
+  unit: '0.000001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
new file mode 100644
index 0000000000..5e3905af98
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
@@ -0,0 +1,438 @@
+import base64
+import json
+import logging
+from collections.abc import Generator
+from typing import Optional, Union
+
+import google.api_core.exceptions as exceptions
+import vertexai.generative_models as glm
+from google.cloud import aiplatform
+from google.oauth2 import service_account
+from vertexai.generative_models import HarmBlockThreshold, HarmCategory
+
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    PromptMessageContentType,
+    PromptMessageTool,
+    SystemPromptMessage,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+
+logger = logging.getLogger(__name__)
+
+GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
+The structure of the {{block}} object can be found in the instructions; use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+<instructions>
+{{instructions}}
+</instructions>
+"""
+
+
+class VertexAiLargeLanguageModel(LargeLanguageModel):
+
+    def _invoke(self, model: str, credentials: dict,
+                prompt_messages: list[PromptMessage], model_parameters: dict,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+                stream: bool = True, user: Optional[str] = None) \
+            -> Union[LLMResult, Generator]:
+        """
+        Invoke large language model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param tools: tools for tool calling
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        :return: full response or stream response chunk generator result
+        """
+        # invoke model
+        return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+
+    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                       tools: Optional[list[PromptMessageTool]] = None) -> int:
+        """
+        Get number of tokens for given prompt messages
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param tools: tools for tool calling
+        :return: number of tokens
+        """
+        prompt = self._convert_messages_to_prompt(prompt_messages)
+
+        return self._get_num_tokens_by_gpt2(prompt)
+
+    def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
+        """
+        Format a list of messages into a full prompt for the Google model
+
+        :param messages: List of PromptMessage to combine.
+        :return: Combined string with necessary human_prompt and ai_prompt tags.
+        """
+        messages = messages.copy()  # don't mutate the original list
+
+        text = "".join(
+            self._convert_one_message_to_text(message)
+            for message in messages
+        )
+
+        return text.rstrip()
+
+    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
+        """
+        Convert tool messages to glm tools
+
+        :param tools: tool messages
+        :return: glm tools
+        """
+        return glm.Tool(
+            function_declarations=[
+                glm.FunctionDeclaration(
+                    name=tool.name,
+                    parameters=glm.Schema(
+                        type=glm.Type.OBJECT,
+                        properties={
+                            key: {
+                                'type_': value.get('type', 'string').upper(),
+                                'description': value.get('description', ''),
+                                'enum': value.get('enum', [])
+                            } for key, value in tool.parameters.get('properties', {}).items()
+                        },
+                        required=tool.parameters.get('required', [])
+                    ),
+                ) for tool in tools
+            ]
+        )
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            ping_message = SystemPromptMessage(content="ping")
+            self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    def _generate(self, model: str, credentials: dict,
+                  prompt_messages: list[PromptMessage], model_parameters: dict,
+                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+                  stream: bool = True, user: Optional[str] = None
+                  ) -> Union[LLMResult, Generator]:
+        """
+        Invoke large language model
+
+        :param model: model name
+        :param credentials: credentials kwargs
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param tools: tools for tool calling
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        :return: full response or stream
response chunk generator result + """ + config_kwargs = model_parameters.copy() + config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample', None) + + if stop: + config_kwargs["stop_sequences"] = stop + + service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"])) + service_accountSA = service_account.Credentials.from_service_account_info(service_account_info) + project_id = credentials["vertex_project_id"] + location = credentials["vertex_location"] + aiplatform.init(credentials=service_accountSA, project=project_id, location=location) + + history = [] + system_instruction = GEMINI_BLOCK_MODE_PROMPT + # hack for gemini-pro-vision, which currently does not support multi-turn chat + if model == "gemini-1.0-pro-vision-001": + last_msg = prompt_messages[-1] + content = self._format_message_to_glm_content(last_msg) + history.append(content) + else: + for msg in prompt_messages: + if isinstance(msg, SystemPromptMessage): + system_instruction = msg.content + else: + content = self._format_message_to_glm_content(msg) + if history and history[-1].role == content.role: + history[-1].parts.extend(content.parts) + else: + history.append(content) + + safety_settings={ + HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, + HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, + HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, + } + + google_model = glm.GenerativeModel( + model_name=model, + system_instruction=system_instruction + ) + + response = google_model.generate_content( + contents=history, + generation_config=glm.GenerationConfig( + **config_kwargs + ), + stream=stream, + safety_settings=safety_settings, + tools=self._convert_tools_to_glm_tool(tools) if tools else None + ) + + if stream: + return self._handle_generate_stream_response(model, credentials, response, prompt_messages) + + return self._handle_generate_response(model, credentials, response, prompt_messages) + + def _handle_generate_response(self, model: str, credentials: dict, response: glm.GenerationResponse, + prompt_messages: list[PromptMessage]) -> LLMResult: + """ + Handle llm response + + :param model: model name + :param credentials: credentials + :param response: response + :param prompt_messages: prompt messages + :return: llm response + """ + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage( + content=response.candidates[0].content.parts[0].text + ) + + # calculate num tokens + prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) + completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + # transform response + result = LLMResult( + model=model, + prompt_messages=prompt_messages, + message=assistant_prompt_message, + usage=usage, + ) + + return result + + def _handle_generate_stream_response(self, model: str, credentials: dict, response: glm.GenerationResponse, + prompt_messages: list[PromptMessage]) -> Generator: + """ + Handle llm stream response + + :param model: model name + :param credentials: credentials + :param response: response + :param prompt_messages: prompt messages + :return: llm response chunk generator result + """ + index = -1 + for chunk in response: + for part in chunk.candidates[0].content.parts: + 
+                assistant_prompt_message = AssistantPromptMessage(
+                    content=''
+                )
+
+                if part.text:
+                    assistant_prompt_message.content += part.text
+
+                if part.function_call:
+                    assistant_prompt_message.tool_calls = [
+                        AssistantPromptMessage.ToolCall(
+                            id=part.function_call.name,
+                            type='function',
+                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                name=part.function_call.name,
+                                arguments=json.dumps({
+                                    key: value
+                                    for key, value in part.function_call.args.items()
+                                })
+                            )
+                        )
+                    ]
+
+                index += 1
+
+                if not hasattr(chunk, 'finish_reason') or not chunk.finish_reason:
+                    # transform assistant message to prompt message
+                    yield LLMResultChunk(
+                        model=model,
+                        prompt_messages=prompt_messages,
+                        delta=LLMResultChunkDelta(
+                            index=index,
+                            message=assistant_prompt_message
+                        )
+                    )
+                else:
+                    # calculate num tokens
+                    prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+                    completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+
+                    # transform usage
+                    usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+                    yield LLMResultChunk(
+                        model=model,
+                        prompt_messages=prompt_messages,
+                        delta=LLMResultChunkDelta(
+                            index=index,
+                            message=assistant_prompt_message,
+                            finish_reason=chunk.candidates[0].finish_reason,
+                            usage=usage
+                        )
+                    )
+
+    def _convert_one_message_to_text(self, message: PromptMessage) -> str:
+        """
+        Convert a single message to a string.
+
+        :param message: PromptMessage to convert.
+        :return: String representation of the message.
+        """
+        human_prompt = "\n\nuser:"
+        ai_prompt = "\n\nmodel:"
+
+        content = message.content
+        if isinstance(content, list):
+            content = "".join(
+                c.data for c in content if c.type != PromptMessageContentType.IMAGE
+            )
+
+        if isinstance(message, UserPromptMessage):
+            message_text = f"{human_prompt} {content}"
+        elif isinstance(message, AssistantPromptMessage):
+            message_text = f"{ai_prompt} {content}"
+        elif isinstance(message, SystemPromptMessage):
+            message_text = f"{human_prompt} {content}"
+        elif isinstance(message, ToolPromptMessage):
+            message_text = f"{human_prompt} {content}"
+        else:
+            raise ValueError(f"Got unknown type {message}")
+
+        return message_text
+
+    def _format_message_to_glm_content(self, message: PromptMessage) -> glm.Content:
+        """
+        Format a single message into glm.Content for Google API
+
+        :param message: one PromptMessage
+        :return: glm Content representation of message
+        """
+        if isinstance(message, UserPromptMessage):
+            if isinstance(message.content, str):
+                glm_content = glm.Content(role="user", parts=[glm.Part.from_text(message.content)])
+            else:
+                parts = []
+                for c in message.content:
+                    if c.type == PromptMessageContentType.TEXT:
+                        parts.append(glm.Part.from_text(c.data))
+                    else:
+                        metadata, data = c.data.split(',', 1)
+                        mime_type = metadata.split(';', 1)[0].split(':')[1]
+                        blob = {"inline_data": {"mime_type": mime_type, "data": data}}
+                        parts.append(blob)
+
+                glm_content = glm.Content(role="user", parts=parts)
+            return glm_content
+        elif isinstance(message, AssistantPromptMessage):
+            if message.content:
+                glm_content = glm.Content(role="model", parts=[glm.Part.from_text(message.content)])
+            if message.tool_calls:
+                glm_content = glm.Content(role="model", parts=[glm.Part.from_function_response(glm.FunctionCall(
+                    name=message.tool_calls[0].function.name,
+                    args=json.loads(message.tool_calls[0].function.arguments),
+                ))])
+            return glm_content
+        elif isinstance(message, ToolPromptMessage):
+            glm_content = glm.Content(role="function", parts=[glm.Part(function_response=glm.FunctionResponse(
+                name=message.name,
+                response={
+                    "response": message.content
+                }
+            ))])
+            return glm_content
+        else:
+            raise ValueError(f"Got unknown type {message}")
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: [
+                exceptions.RetryError
+            ],
+            InvokeServerUnavailableError: [
+                exceptions.ServiceUnavailable,
+                exceptions.InternalServerError,
+                exceptions.BadGateway,
+                exceptions.GatewayTimeout,
+                exceptions.DeadlineExceeded
+            ],
+            InvokeRateLimitError: [
+                exceptions.ResourceExhausted,
+                exceptions.TooManyRequests
+            ],
+            InvokeAuthorizationError: [
+                exceptions.Unauthenticated,
+                exceptions.PermissionDenied,
+                exceptions.Forbidden
+            ],
+            InvokeBadRequestError: [
+                exceptions.BadRequest,
+                exceptions.InvalidArgument,
+                exceptions.FailedPrecondition,
+                exceptions.OutOfRange,
+                exceptions.NotFound,
+                exceptions.MethodNotAllowed,
+                exceptions.Conflict,
+                exceptions.AlreadyExists,
+                exceptions.Aborted,
+                exceptions.LengthRequired,
+                exceptions.PreconditionFailed,
+                exceptions.RequestRangeNotSatisfiable,
+                exceptions.Cancelled,
+            ]
+        }
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/__init__.py b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-embedding-004.yaml b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-embedding-004.yaml
new file mode 100644
index 0000000000..32db6faf89
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-embedding-004.yaml
@@ -0,0 +1,8 @@
+model: text-embedding-004
+model_type: text-embedding
+model_properties:
+  context_size: 2048
+pricing:
+  input: '0.00013'
+  unit: '0.001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-multilingual-embedding-002.yaml b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-multilingual-embedding-002.yaml
new file mode 100644
index 0000000000..2ec0eea9f2
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-multilingual-embedding-002.yaml
@@ -0,0 +1,8 @@
+model: text-multilingual-embedding-002
+model_type: text-embedding
+model_properties:
+  context_size: 2048
+pricing:
+  input: '0.00013'
+  unit: '0.001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py
new file mode 100644
index 0000000000..ece63806c3
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py
@@ -0,0 +1,193 @@
+import base64
+import json
+import time
+from decimal import Decimal
+from typing import Optional
+
+import tiktoken
+from google.cloud import aiplatform
+from google.oauth2 import service_account
+from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    FetchFrom,
+    ModelPropertyKey,
+    ModelType,
+    PriceConfig,
+    PriceType,
+)
+from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.model_runtime.model_providers.vertex_ai._common import _CommonVertexAi
+
+
+class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
+    """
+    Model class for Vertex AI text embedding model.
+    """
+
+    def _invoke(self, model: str, credentials: dict,
+                texts: list[str], user: Optional[str] = None) \
+            -> TextEmbeddingResult:
+        """
+        Invoke text embedding model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :param user: unique user id
+        :return: embeddings result
+        """
+        service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
+        service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
+        project_id = credentials["vertex_project_id"]
+        location = credentials["vertex_location"]
+        aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
+
+        client = VertexTextEmbeddingModel.from_pretrained(model)
+
+        embeddings_batch, embedding_used_tokens = self._embedding_invoke(
+            client=client,
+            texts=texts
+        )
+
+        # calc usage
+        usage = self._calc_response_usage(
+            model=model,
+            credentials=credentials,
+            tokens=embedding_used_tokens
+        )
+
+        return TextEmbeddingResult(
+            embeddings=embeddings_batch,
+            usage=usage,
+            model=model
+        )
+
+    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+        """
+        Get number of tokens for given prompt messages
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :return: number of tokens
+        """
+        if len(texts) == 0:
+            return 0
+
+        try:
+            enc = tiktoken.encoding_for_model(model)
+        except KeyError:
+            enc = tiktoken.get_encoding("cl100k_base")
+
+        total_num_tokens = 0
+        for text in texts:
+            # calculate the number of tokens in the encoded text
+            tokenized_text = enc.encode(text)
+            total_num_tokens += len(tokenized_text)
+
+        return total_num_tokens
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
+            service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
+            project_id = credentials["vertex_project_id"]
+            location = credentials["vertex_location"]
+            aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
+
+            client = VertexTextEmbeddingModel.from_pretrained(model)
+
+            # call embedding model
+            self._embedding_invoke(
+                client=client,
+                texts=['ping']
+            )
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    def _embedding_invoke(self, client: VertexTextEmbeddingModel, texts: list[str]) -> tuple[list[list[float]], int]:
+        """
+        Invoke embedding model
+
+        :param client: model client
+        :param texts: texts to embed
+        :return: embeddings and used tokens
+        """
+        response = client.get_embeddings(texts)
+
+        embeddings = []
+        token_usage = 0
+
+        for embedding in response:
+            embeddings.append(embedding.values)
+            token_usage += int(embedding.statistics.token_count)
+
+        return embeddings, token_usage
+
+    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
+        """
+        Calculate response usage
+
+        :param model: model name
+        :param credentials: model credentials
+        :param tokens: input tokens
+        :return: usage
+        """
+        # get input price info
+        input_price_info = self.get_price(
+            model=model,
+            credentials=credentials,
+            price_type=PriceType.INPUT,
+            tokens=tokens
+        )
+
+        # transform usage
+        usage = EmbeddingUsage(
+            tokens=tokens,
+            total_tokens=tokens,
+            unit_price=input_price_info.unit_price,
+            price_unit=input_price_info.unit,
+            total_price=input_price_info.total_amount,
+            currency=input_price_info.currency,
+            latency=time.perf_counter() - self.started_at
+        )
+
+        return usage
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+        """
+        generate custom model entities from credentials
+        """
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(en_US=model),
+            model_type=ModelType.TEXT_EMBEDDING,
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
+                ModelPropertyKey.MAX_CHUNKS: 1,
+            },
+            parameter_rules=[],
+            pricing=PriceConfig(
+                input=Decimal(credentials.get('input_price', 0)),
+                unit=Decimal(credentials.get('unit', 0)),
+                currency=credentials.get('currency', "USD")
+            )
+        )
+
+        return entity
diff --git a/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py
new file mode 100644
index 0000000000..3cbfb088d1
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py
@@ -0,0 +1,31 @@
+import logging
+
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class VertexAiProvider(ModelProvider):
+    def validate_provider_credentials(self, credentials: dict) -> None:
+        """
+        Validate provider credentials
+
+        If validation fails, raise an exception.
+
+        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
+        """
+        try:
+            model_instance = self.get_model_instance(ModelType.LLM)
+
+            # Use the `gemini-1.0-pro-002` model for validation.
+            model_instance.validate_credentials(
+                model='gemini-1.0-pro-002',
+                credentials=credentials
+            )
+        except CredentialsValidateFailedError as ex:
+            raise ex
+        except Exception as ex:
+            logger.exception(f'{self.get_provider_schema().provider} credentials validation failed')
+            raise ex
diff --git a/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.yaml b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.yaml
new file mode 100644
index 0000000000..8b7f216b55
--- /dev/null
+++ b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.yaml
@@ -0,0 +1,43 @@
+provider: vertex_ai
+label:
+  en_US: Vertex AI | Google Cloud Platform
+description:
+  en_US: Vertex AI in Google Cloud Platform.
+icon_small: + en_US: icon_s_en.svg +icon_large: + en_US: icon_l_en.png +background: "#FCFDFF" +help: + title: + en_US: Get your Access Details from Google + url: + en_US: https://cloud.google.com/vertex-ai/ +supported_model_types: + - llm + - text-embedding +configurate_methods: + - predefined-model +provider_credential_schema: + credential_form_schemas: + - variable: vertex_project_id + label: + en_US: Project ID + type: text-input + required: true + placeholder: + en_US: Enter your Google Cloud Project ID + - variable: vertex_location + label: + en_US: Location + type: text-input + required: true + placeholder: + en_US: Enter your Google Cloud Location + - variable: vertex_service_account_key + label: + en_US: Service Account Key + type: secret-input + required: true + placeholder: + en_US: Enter your Google Cloud Service Account Key in base64 format diff --git a/api/requirements.txt b/api/requirements.txt index c9e9c2fa29..306f600afc 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -84,3 +84,4 @@ pgvecto-rs==0.1.4 firecrawl-py==0.0.5 oss2==2.18.5 pgvector==0.2.5 +google-cloud-aiplatform==1.49.0
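
Note for reviewers: the `vertex_service_account_key` credential is consumed as
base64-encoded service account JSON; both llm.py and text_embedding.py call
`json.loads(base64.b64decode(credentials["vertex_service_account_key"]))`. A minimal
sketch for producing that value from a key file exported from Google Cloud IAM (the
file name is illustrative, not something this patch defines):

    import base64

    # Hypothetical path to a service-account key downloaded from Google Cloud IAM.
    with open("service-account.json", "rb") as f:
        key_b64 = base64.b64encode(f.read()).decode("ascii")

    # Paste this value into the provider's "Service Account Key" credential field.
    print(key_b64)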
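The LLM path initializes the SDK per request and then calls `glm.GenerativeModel`.
A standalone sketch of that same call sequence, with placeholder project, location,
and key-file values that are assumptions for illustration, not values from this
patch:

    import vertexai.generative_models as glm
    from google.cloud import aiplatform
    from google.oauth2 import service_account

    # Placeholder credentials; the provider reads these from its credential form.
    creds = service_account.Credentials.from_service_account_file("service-account.json")
    aiplatform.init(credentials=creds, project="my-gcp-project", location="us-central1")

    # Same model name the provider uses for credential validation.
    model = glm.GenerativeModel(model_name="gemini-1.0-pro-002")
    response = model.generate_content(contents="ping", stream=False)
    print(response.candidates[0].content.parts[0].text)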
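Likewise, the embedding path wraps `TextEmbeddingModel.get_embeddings` and sums
`statistics.token_count` across results for usage accounting. A small sketch,
assuming `aiplatform.init(...)` has already been called as above:

    from vertexai.language_models import TextEmbeddingModel

    embedding_model = TextEmbeddingModel.from_pretrained("text-embedding-004")
    results = embedding_model.get_embeddings(["ping"])
    for result in results:
        # Each result carries the embedding vector and per-text token statistics.
        print(len(result.values), result.statistics.token_count)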