From 3bf84309c6a742ec3fd4ca796ddd468a960ace7d Mon Sep 17 00:00:00 2001 From: Toshiaki Takeuchi Date: Thu, 18 Jan 2024 17:28:14 -0500 Subject: [PATCH] minor bugfixes - but the test still fails --- examples/ExampleParallelFunctionCalls.mlx | Bin 6313 -> 6558 bytes openAIMessages.m | 9 ++-- tests/topenAIMessages.m | 59 +++++++++++++--------- 3 files changed, 41 insertions(+), 27 deletions(-) diff --git a/examples/ExampleParallelFunctionCalls.mlx b/examples/ExampleParallelFunctionCalls.mlx index c8d893a08f32efabfbb1b44434a16e8dcaae1979..5565a1a0f46b1f2dc99ad5f0abb4593638950f59 100644 GIT binary patch delta 4432 zcmV-W5wGs4F`hH9*#Zg~*~&Uj2mk=^CzIm>FMn%qo-qn z!X81JSjiqg3=TL08LB9j!yf;1eYXE2e{=GB(v4J@rywoE@zQiBJwCEF>vlRO96>5f zpno!ucyFL|DlC50!_GwMDAOtgV-PM)I!6ZwKXg(NOHPpYDw@(~=H4!~`?#{x1lXke zD3Ttxu}y$?=*Bp#E+#OrbgffQa&PwJ6RpDfBVpfPUNa=ZGNDD1K*9#O3@z@Eh4|gj zRO77=I=tgC~(@2&e}5uhjy%Ak#=#!3J8TESJXW zJhZunh|zK^cY9gGKVTc~0QCtg*}neMP5?)1zh)@PVb88kpO@BT5krA>X=zxuJyrJPUC>k_EO7231+mz{^VN#?( z<4zNJVt)ssJOuXL5q`?3q=_p(CD7z}d%PpF+{irM!XTD~%hld(F?LlEqG_qiDh#%6 zDFd_y<^21sq~r0DN@X5zkrDS9FWt#BC@xwUZy&b^5yQbH2C)<@Xm5aXNq=RCVW%1N zI7U5kW1mG(giM!&YOHeXNBAX=F*4RbJq(txZh;}l2=sH8C`YdYLvyN;WS-tT;MGVp zbRA4#!_TBt4PLkHmteeo4Z;Gtc8a;3d~SliyUI?TBVq=O%M7=hC|8HGv5KWvd_!H7 zUN-s(!Aq5e2KgRV6wFs8r9b;LF zf=mMTK_*kT6$z8#8dfZMmMLw^RKKd~K$G2m&~dAy4)`#$`>%rU_pt=JSNn&D2R~rQ zcRQ(cM|GXY|EY2&G>}PQAtJD_G*)0$imnt7B~?$V4xqIuk^%$n^M4Ck?H?|vxP|>| zr4rK{ybJ}t2MM}Rp$3yFXVD+~=tc2h#^=fko0~qa z!g{E+4}aq>FPEbOi}JiHO4Q1BT4-J12Gt-xVcj24E(t(urCS(+UxNJ42V_-Pj-bU9 zc8QfmZXoiRB9BqBk$Fpp6XSh;EF3dlp>tHWKtCn&tD8-Upp-r7>A;Z{&%4gT=VXBJW z28+Vjyy0J4;}k5Om;ZLr>uA7Q7#@*hwe_Zs=mFQgf6Dz$ zx+b2>*e{w8y;dXj&pcWkPEl+jhKK=H8Yl`p@7zJIcivBNG~QAz*)b@Z<7b_C%A ziw9+oT2`m^c_S5B#o@8P{al8t#%2&h5>=VS-tb8m~64NX@GBS8hCbfJo9_I+cO%p%74ZN-)Of>nZ&EwD&Msw z?bbK!UHL4&o;?!o!E96*HFl)AF8UTjPZvudCw~Q$8hU}Y<}(J;88o-l)u8d852e!3 z$0ikfDQZ^SEP9m;>BIfH0=(jJh}Y=*8hEpAiu;ELReR=k&@P^Yl4jP?ne(At(faez zrnKCV#jfZa)pXb*pN-(z8hiwQkZL0N%3Y(M(_gE$sniVD_^%oIuDt9!??rn&n>!G} z6n|mnOHfms+d`Pzg%gxU8H6}xA+%7j8b?q(iZ~JX3EWkaDZEs>UwNpp@@lzXI+uxK zL^$lqJcYz-Y^V}8Ep#3gt6WmN-(UfuKKG%A#O(#NQ#~}L zrXz5gUNc$fq5D;nX#hOHlqS2#PvFL@Ab%T!T+5RJJrVY3f`8L6qQXROXToqs@|00y!eBnnax{7PyCjVT#UM+dgme_B@`|(KAsM*{}Cb{tHk`0|YGq000O8001EX zGIqVk3kLuIvpE0&5dahbZDDk5VUr^ZA%9z4a@#f#e$Od5c!*yVQzS)+QfQ{CQn&8J ziW{j(Cw8U-kt>NYe^d~(<#_zsEA$$9=>^i8^dtpHDW(+IP!d|hPVDiFiNylMx4YQ= z0Lk`~c^D{XKoUQW_K>z)M+!u4?D^4T5B)mae`KL2jt$`mJP*9*DV?P>aqJK%et(K7 zW?g4ENHJx@2NY9qLL8+yu+??B9pH;NrFp}P-86)V+G|NOEip)-C zfln!r#ef-!XEEU~aq7DbTP?J0bxl{jUUNQYmk@lR(2E)DAXyPb1}N45+4{2}B+cE)8*zN$`X#876y#5c>h|d&@?6F?^Q^cHKBc zS33`EHD8TQNgBDCD-NP@?AYfx8RjDJ4t?*LAAvK*K?1fa>RQq7QZD<5>YSh*zRKTUo{@Ms8_a+B@l^~h-KEX9C}Pf+ay`N`qptj4;T7_QKpns5n6W^e$;Z~g8o4Gx-gh=2P>ukfUqZgC0y@b&nh z|C*i}{qEsANb&5<0RNZ%oh;!Ei{+1AsZQfaw&7;z!B=N6J3Ik?SXJ zTvK{o#ZQ!${^9+8jdu^@(~DT~DVJ*X5AS<51xFt5Ie%gcNftdi)9>w~>hZzSDVlj_ zy#IEOQ3zSQVVV;7#A{2BeH$KVwl?s9TDV-)sdnhKE)i__cZt&@s(&ghKP!8C6Z4Rq8QH3!l zK(biqEq_pQthXT1q(~Y&dqo~!axe`! z7*h1Dlo zo{`soL!8Zw&FcR(g1P{eCXlV1%~vMgrjzAV!Vp(1qvVQJMwZv)T*463Z!WGJ8abP2 zh&N!Pf*VjyCmP}n*!bKHXjEcj$r*`+fy=PXC4C->uVT#$Hxpg+U z0e>1f=`h9TbhCozbY0FmOz}b8tl&Xim(vbY%x=vJp44?YSun+9!Q47okYCivxq>O? z3TDY%p@eFq%c(+BoY9(9%xH41&=hC1=4NqL_=N1Mxk4GCDCY_-F;{4n%oV=={@ai< zg_ihoqE)@i2}4d4TH^HGs@~;%O@oJ&|@nrBt;O308i zg|?U}v@6K-8gj1C7ITI6*1y590U9}3Xp8GjyMpzmvB4J^#Pz1V^?DQ@HNDkug2}hB~R<lRpz0lQ0-J12h%@lRpz0lXevk7EntC W1^@s600;mG05}2w08keI0002aeopuR delta 4176 zcmV-W5U=l^GpRAK*#Zg_#e_@n2LJ#~Ba`C-FMn-s+sG0A?q4xo;LZ|`rfuc!dQc*# zjuNN7HjNQG2nxdp*5pW{Oz!e_mzIU%?zeYl`69}eE_TmK4Yc|on&j^6v-A4Q$}e8t zWGP&uv58g#u^;XUKo#jYQR9L5cy_e&lX!XXV%Cp!RAi_e(0FD0vw@g6m-lyfZ8X75 z+J8{zsOa5Dn@l>oHRIixHgRrrgw~R{EZyzx?ftl$$wUbjd8y+$UvvL;V*KUQ%~OO~ z|5~O4;S!gk*yR_~sD7B@$nmq?W|IGM2OoJA7*3>ndvXQ@1afI)nqms0LPd^t2oc>| zp6a{+Q;PWl3cPXm25~k?EZE2-xdS`VGk*(G0p7?gCxLoIe>x)x3n7iAlQ1$mgF;zn zipUiPW8meW@OxQVmLfGhAW(x@c3~q z>FIm2@!_$5DWgai%59}u#;C44BBXE8-I8&fOE>9khv9Y?E-ntf2-RZbPD~5kmw)#& zLCFq!Z=<|frF7QwW<3ERDy-8P8Za19FhHc0Blov3O1wQ~3W+L1hVC|ES9ceN-KyXW z;9M}wVf}t4E&Qw}3SMgrK*F;hcIGGz=ro2*rWnFs$-ODw69R5wN<(9ffAJ5(NN3{t z7SKohEEU3OmbeK^86nSYPy{@+_kR(1alQ2esU`PPZ*6nIVrm zOz{T(jtqH>@O_W2Dz0GC$~P&tF`6BXv~fi?U?>9dj(=DChfgYg;owZ` z)DA{Z!+>v5p(#yh(B@iM+(4+D{}$9i8#L=&ke0{Q@*UQCP)7kpQ8HTKId#$%c1Wwh zfUNfZZ^GB*YINvGpPx<Kmoxc^CRYjgVKw`$MiJAsVC27Ny{)p#S3$dlk_U@|e;u ziCGjD;{Y^Aj8?5mU4L(e$JCnJKv7;DACDgz@H?@wb?$E5|B%x>w7D3T$+7J7s7y1d zuQ#&E{^FSVBD3+bxeQ>&WjIaQXF2EVs0Cy>k4N&2 zG`T{K4cg$M$}ekg19>>7m2wVY@rgo*dKq4|(rXDRQE*B=b!4Q95*z6N&h*t>m%R;< z!oHa3*ZWRm;-$?=Bh#z$*f@CnD)##{55E1D`0*&gG_Ffdu)Uyr+g(rboWCLc-6e#N z#{>U>eh>KsBaXd#8)b5s@Piuj;|KOq@%~V4|4sp~{&+%B*z1=H`ig|JNwczLD)N{j zbr1cm=YPwFk=HUwt)$S<>}TMg0S)_S7PEfNkX&t`&4Q%?UbQVh6PJMj!V&0 zI;wY_BoF&cJ{|acUnO8XH zSw6&BqV3eOssWTB5}P5PoOXC{aNc4%CGEH<_kYVDyvYrVtGD$Xh)#9%?{Sj8rqhO}bfIEt z<{d$EL^lv!(Ot-KX1)ZCSO4IhlJSwZd2y`?q?Tv`@7}CLvYZdN_}<9{mCFQ;{({z; z)qgrH@`;?c_-yg4@0(^-#w-;r{wca-|DJla4qVEBe*ALYKHcaoDthm}=Nr9s2OK2$ zZUO6k#gmMC^f&b~Ae-{coFbNgt`HD@;lKEk|)B zsLlt(H;uEO$hr&B!f|6jk##&)+Tf7-Jb!9OxL;`kphCzO{|^e-sdMfkp|s>J8WIug z@6{cS-$5M;WSVtwminnp!TNLWp|E_{zbQDq1`gcD^C0oCMm`C@QhL_p{pDTXKQH~q zDW_r<$t|8S)cM>b?dumD5oB)Xcf8Jt^a*xck!<=2mc08OBMqJ zEdT%j2mk;8Api$5rtWnH000FrlXeR%e_CB~+cpq>?6 zDe1I!rUQ|XM43Md2-;FTe(e={jlT2(?M-@;E$xMjF0^DzRvHO9F z{dg7z+9jchA4iACIH(|vL~iW)(dZDp?H@mEqQ{O+!U6F-;=NAUG-ZinhlJrLf0$vI zy2L?>87%KH%!m`>D8+%TZ_4ce&*PM3121;dkVMSZ^HMt&+w(!1IFw*-ZtH6@JDCPP zV}vdiOldrgDSwGmIjGrsZfxrty86oIY|jp*c&@SelpG?=pAn!52)E0zx!AeH)W?H> z^cV$y*t!rdR%bg!I3x~^5?dEfe>MkA9BuiB3)|WrnK6o8)=d)@hlKK_d*ASROh>{G z2`F77u8rCndLEA>?MY0~zJ^95@(AV95C=knM_j3D!6ON=AMm*etngy;tq>fzafojA zzOnVp8k^BHaz!YfMZ?&!uW-`OM8rGsy_bGOoFNVpV(YT0HT@3bvLgu2e}!yKej5C> z@!~Orlu+kEz#d_3Ov&&NP_Y0H+G7$-PvdDi)e>0bGNicyPLe~EE#V&PgY(gY5qp$r z@J)2g$J&+8#@f&)ftR#xeen{T##f1B*t+=FcH1*aSgQfwjFNWf0(!=#ZCy{?F$6a`h^Hii$ze>vZ7`%GJ;d}Pf1N&rs87Ms10VR@ zTsJF~Iud(#$Fc7c?G5>m`jmh_-826uI4 z>N(m!M}Bg$xU0c5h{+9lzb`exyJ|i2C&Q;~luXj`GMnA2g<2L>r4iQ^}4;;YxRE3)w&aal}l zeM3ilLhuBPOz$N2NN@@lcmfxAP&m6rPsPc2C5Qy_J#d&fYA|ren4+8g4HcftN~L88 za*B(Cd;{KD$U zQ6$iftxFeWA2l*BerCh3nO}=VvKjl7MWX(gXl_a=?^ss^voWEXAGJm4(H?3&9BIcC zN3Nf^@xIpWw6%u$=kML+d_TUJ$8F81S>lxe^W@Ozbx-?5ANEU|P0VTtF2NFYxq$WjDD;9||CQ2Zt5=cu5#4vV@ zM4eq#e!PMno>g5AdSA%KD!8A%&#u{oYO*xik z>5hL>iN!ReSj>V;SrJLrQiCz&V9e4bnW1j`ro8W)1vmZzvdmC-ep6n{%$?UVb!61t z-;~!fv*21*L}b+czgm&g%S!noqwfDzc{QtUf7kz42r3d(hD5e@H(c9zmqAujiK@I~ zRSWJ|MO1l1%_XXG`pw0qL!)LBHTeOomGA(n=|oL_0BfIn0IQ`qSZc;$$%AL@Ja}p} zrkZkC^5|J5B!i}!b69dVXziR0Zh=NkIxP7;-74Wd-Bhy)wIKsvsZhgO%^OUS+I6a7Stc=)Lg-ma|Nqlu24X?G1XL|F0W|yGFCJ-SE$P?T7A1b%l$(0 zuem}Isi@`(4LMh66wDQh_-d@C3Jv+!iAMQ;mZ++kLPMU3trC9E-{QB4Tfpp`D5xW& z<_Rr1PiSpdzE=pWt!jjx`xTJOH~u&O6hBZ)0|YGq000O8001EX9{8#s9h1KjNq;&Q z(smb`U1?%sjb?X*o^8ZJ383BHURJj{b%~ec|IU}g`HRKQwCDqT(48$Y%W{mss;<#? zQ{v}(N#+=ZC~YI#&O(Xz5HMO)lR`P6IuDP&bKs+ffRf1upyxfjY{|K+dWnH)pd8 z6wd<+6vc!~@CN_@O(T=!7dIORGN$fz1^@sBFaQ7%000000000100000005IV7$gNA z_^KZrlV2DW1Go_YlgSbslinB*lcN|K10oXulTj8LlaCk_lkONA13nZ0lTj8LlWG+W alQJ1913wi2lTj8alZY7*23r;Y0001C|NB`0 diff --git a/openAIMessages.m b/openAIMessages.m index 817f05f..82620e5 100644 --- a/openAIMessages.m +++ b/openAIMessages.m @@ -265,10 +265,11 @@ "name", toolCalls(i).function.name, ... "arguments", toolCalls(i).function.arguments); end - - newMessage = struct("role", "assistant", "content", content, "tool_calls", toolsStruct); - if numel(newMessage.tool_calls) == 1 - newMessage.tool_calls = {newMessage.tool_calls}; + if numel(toolsStruct) > 1 + newMessage = struct("role", "assistant", "content", content, "tool_calls", toolsStruct); + else + newMessage = struct("role", "assistant", "content", content, "tool_calls", []); + newMessage.tool_calls = {toolsStruct}; end end diff --git a/tests/topenAIMessages.m b/tests/topenAIMessages.m index c0790ff..4dfcc00 100644 --- a/tests/topenAIMessages.m +++ b/tests/topenAIMessages.m @@ -24,7 +24,7 @@ function differentInputTextAccepted(testCase, ValidTextInput) testCase.verifyWarningFree(@()addSystemMessage(msgs, ValidTextInput, ValidTextInput)); testCase.verifyWarningFree(@()addSystemMessage(msgs, ValidTextInput, ValidTextInput)); testCase.verifyWarningFree(@()addUserMessage(msgs, ValidTextInput)); - testCase.verifyWarningFree(@()addFunctionMessage(msgs, ValidTextInput, ValidTextInput)); + testCase.verifyWarningFree(@()addToolMessage(msgs, ValidTextInput, ValidTextInput, ValidTextInput)); end @@ -59,12 +59,13 @@ function userImageMessageIsAddedWithRemoteImg(testCase) testCase.verifyWarningFree(@()addUserMessageWithImages(msgs, prompt, img)); end - function functionMessageIsAdded(testCase) + function toolMessageIsAdded(testCase) prompt = "20"; name = "sin"; + id = "123"; msgs = openAIMessages; - systemPrompt = struct("role", "function", "name", name, "content", prompt); - msgs = addFunctionMessage(msgs, name, prompt); + systemPrompt = struct("tool_call_id", id, "role", "tool", "name", name, "content", prompt); + msgs = addToolMessage(msgs, id, name, prompt); testCase.verifyEqual(msgs.Messages{1}, systemPrompt); end @@ -76,27 +77,39 @@ function assistantMessageIsAdded(testCase) testCase.verifyEqual(msgs.Messages{1}, assistantPrompt); end - function assistantFunctionCallMessageIsAdded(testCase) + function assistantToolCallMessageIsAdded(testCase) msgs = openAIMessages; functionName = "functionName"; args = "{""arg1"": 1, ""arg2"": 2, ""arg3"": ""3""}"; funCall = struct("name", functionName, "arguments", args); toolCall = struct("id", "123", "type", "function", "function", funCall); - functionCallPrompt = struct("role", "assistant", "content", "","tool_calls", toolCall); - functionCallPrompt.tool_calls = {functionCallPrompt.tool_calls}; - msgs = addResponseMessage(msgs, functionCallPrompt); - testCase.verifyEqual(msgs.Messages{1}, functionCallPrompt); + toolCallPrompt = struct("role", "assistant", "content", "", "tool_calls", []); + toolCallPrompt.tool_calls = {toolCall}; + msgs = addResponseMessage(msgs, toolCallPrompt); + testCase.verifyEqual(msgs.Messages{1}, toolCallPrompt); end - function assistantFunctionCallMessageWithoutArgsIsAdded(testCase) + function assistantToolCallMessageWithoutArgsIsAdded(testCase) msgs = openAIMessages; functionName = "functionName"; funCall = struct("name", functionName, "arguments", "{}"); toolCall = struct("id", "123", "type", "function", "function", funCall); - functionCallPrompt = struct("role", "assistant", "content", "","tool_calls", toolCall); - functionCallPrompt.tool_calls = {functionCallPrompt.tool_calls}; - msgs = addResponseMessage(msgs, functionCallPrompt); - testCase.verifyEqual(msgs.Messages{1}, functionCallPrompt); + toolCallPrompt = struct("role", "assistant", "content", "","tool_calls", []); + toolCallPrompt.tool_calls = {toolCall}; + msgs = addResponseMessage(msgs, toolCallPrompt); + testCase.verifyEqual(msgs.Messages{1}, toolCallPrompt); + end + + function assistantParallelToolCallMessageIsAdded(testCase) + msgs = openAIMessages; + functionName = "functionName"; + args = "{""arg1"": 1, ""arg2"": 2, ""arg3"": ""3""}"; + funCall = struct("name", functionName, "arguments", args); + toolCall = struct("id", "123", "type", "function", "function", funCall); + toolCallPrompt = struct("role", "assistant", "content", "", "tool_calls", []); + toolCallPrompt.tool_calls = [toolCall,toolCall,toolCall]; + msgs = addResponseMessage(msgs, toolCallPrompt); + testCase.verifyEqual(msgs.Messages{1}, toolCallPrompt); end function messageGetsRemoved(testCase) @@ -105,7 +118,7 @@ function messageGetsRemoved(testCase) msgs = addSystemMessage(msgs, "name", "content"); msgs = addUserMessage(msgs, "content"); - msgs = addFunctionMessage(msgs, "name", "content"); + msgs = addToolMessage(msgs, "123", "name", "content"); sizeMsgs = length(msgs.Messages); % Message exists before removal msgToBeRemoved = msgs.Messages{idx}; @@ -121,7 +134,7 @@ function removalIdxCantBeLargerThanNumElements(testCase) msgs = addSystemMessage(msgs, "name", "content"); msgs = addUserMessage(msgs, "content"); - msgs = addFunctionMessage(msgs, "name", "content"); + msgs = addToolMessage(msgs, "123", "name", "content"); sizeMsgs = length(msgs.Messages); testCase.verifyError(@()removeMessage(msgs, sizeMsgs+1), "llms:mustBeValidIndex"); @@ -144,7 +157,7 @@ function invalidInputsUserImagesPrompt(testCase, InvalidInputsUserImagesPrompt) function invalidInputsFunctionPrompt(testCase, InvalidInputsFunctionPrompt) msgs = openAIMessages; - testCase.verifyError(@()addFunctionMessage(msgs,InvalidInputsFunctionPrompt.Input{:}), InvalidInputsFunctionPrompt.Error); + testCase.verifyError(@()addToolMessage(msgs,InvalidInputsFunctionPrompt.Input{:}), InvalidInputsFunctionPrompt.Error); end function invalidInputsRemove(testCase, InvalidRemoveMessage) @@ -231,27 +244,27 @@ function invalidInputsResponsePrompt(testCase, InvalidInputsResponseMessage) function invalidFunctionPrompt = iGetInvalidFunctionPrompt invalidFunctionPrompt = struct( ... "NonStringInputName", ... - struct("Input", {{123, "content"}}, ... + struct("Input", {{"123", 123, "content"}}, ... "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... ... "NonStringInputContent", ... - struct("Input", {{"name", 123}}, ... + struct("Input", {{"123", "name", 123}}, ... "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... ... "EmptytName", ... - struct("Input", {{"", "content"}}, ... + struct("Input", {{"123", "", "content"}}, ... "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... ... "EmptytContent", ... - struct("Input", {{"name", ""}}, ... + struct("Input", {{"123", "name", ""}}, ... "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... ... "NonScalarInputName", ... - struct("Input", {{["name1" "name2"], "content"}}, ... + struct("Input", {{"123", ["name1" "name2"], "content"}}, ... "Error", "MATLAB:validators:mustBeTextScalar"),... ... "NonScalarInputContent", ... - struct("Input", {{"name", ["content1", "content2"]}}, ... + struct("Input", {{"123","name", ["content1", "content2"]}}, ... "Error", "MATLAB:validators:mustBeTextScalar")); end