Skip to content

Commit

Permalink
minor bugfixes - but the test still fails
Browse files Browse the repository at this point in the history
  • Loading branch information
toshiakit committed Jan 18, 2024
1 parent ec0d493 commit 3bf8430
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 27 deletions.
Binary file modified examples/ExampleParallelFunctionCalls.mlx
Binary file not shown.
9 changes: 5 additions & 4 deletions openAIMessages.m
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,11 @@
"name", toolCalls(i).function.name, ...
"arguments", toolCalls(i).function.arguments);
end

newMessage = struct("role", "assistant", "content", content, "tool_calls", toolsStruct);
if numel(newMessage.tool_calls) == 1
newMessage.tool_calls = {newMessage.tool_calls};
if numel(toolsStruct) > 1
newMessage = struct("role", "assistant", "content", content, "tool_calls", toolsStruct);
else
newMessage = struct("role", "assistant", "content", content, "tool_calls", []);
newMessage.tool_calls = {toolsStruct};
end
end

Expand Down
59 changes: 36 additions & 23 deletions tests/topenAIMessages.m
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ function differentInputTextAccepted(testCase, ValidTextInput)
testCase.verifyWarningFree(@()addSystemMessage(msgs, ValidTextInput, ValidTextInput));
testCase.verifyWarningFree(@()addSystemMessage(msgs, ValidTextInput, ValidTextInput));
testCase.verifyWarningFree(@()addUserMessage(msgs, ValidTextInput));
testCase.verifyWarningFree(@()addFunctionMessage(msgs, ValidTextInput, ValidTextInput));
testCase.verifyWarningFree(@()addToolMessage(msgs, ValidTextInput, ValidTextInput, ValidTextInput));
end


Expand Down Expand Up @@ -59,12 +59,13 @@ function userImageMessageIsAddedWithRemoteImg(testCase)
testCase.verifyWarningFree(@()addUserMessageWithImages(msgs, prompt, img));
end

function functionMessageIsAdded(testCase)
function toolMessageIsAdded(testCase)
prompt = "20";
name = "sin";
id = "123";
msgs = openAIMessages;
systemPrompt = struct("role", "function", "name", name, "content", prompt);
msgs = addFunctionMessage(msgs, name, prompt);
systemPrompt = struct("tool_call_id", id, "role", "tool", "name", name, "content", prompt);
msgs = addToolMessage(msgs, id, name, prompt);
testCase.verifyEqual(msgs.Messages{1}, systemPrompt);
end

Expand All @@ -76,27 +77,39 @@ function assistantMessageIsAdded(testCase)
testCase.verifyEqual(msgs.Messages{1}, assistantPrompt);
end

function assistantFunctionCallMessageIsAdded(testCase)
function assistantToolCallMessageIsAdded(testCase)
msgs = openAIMessages;
functionName = "functionName";
args = "{""arg1"": 1, ""arg2"": 2, ""arg3"": ""3""}";
funCall = struct("name", functionName, "arguments", args);
toolCall = struct("id", "123", "type", "function", "function", funCall);
functionCallPrompt = struct("role", "assistant", "content", "","tool_calls", toolCall);
functionCallPrompt.tool_calls = {functionCallPrompt.tool_calls};
msgs = addResponseMessage(msgs, functionCallPrompt);
testCase.verifyEqual(msgs.Messages{1}, functionCallPrompt);
toolCallPrompt = struct("role", "assistant", "content", "", "tool_calls", []);
toolCallPrompt.tool_calls = {toolCall};
msgs = addResponseMessage(msgs, toolCallPrompt);
testCase.verifyEqual(msgs.Messages{1}, toolCallPrompt);
end

function assistantFunctionCallMessageWithoutArgsIsAdded(testCase)
function assistantToolCallMessageWithoutArgsIsAdded(testCase)
msgs = openAIMessages;
functionName = "functionName";
funCall = struct("name", functionName, "arguments", "{}");
toolCall = struct("id", "123", "type", "function", "function", funCall);
functionCallPrompt = struct("role", "assistant", "content", "","tool_calls", toolCall);
functionCallPrompt.tool_calls = {functionCallPrompt.tool_calls};
msgs = addResponseMessage(msgs, functionCallPrompt);
testCase.verifyEqual(msgs.Messages{1}, functionCallPrompt);
toolCallPrompt = struct("role", "assistant", "content", "","tool_calls", []);
toolCallPrompt.tool_calls = {toolCall};
msgs = addResponseMessage(msgs, toolCallPrompt);
testCase.verifyEqual(msgs.Messages{1}, toolCallPrompt);
end

function assistantParallelToolCallMessageIsAdded(testCase)
msgs = openAIMessages;
functionName = "functionName";
args = "{""arg1"": 1, ""arg2"": 2, ""arg3"": ""3""}";
funCall = struct("name", functionName, "arguments", args);
toolCall = struct("id", "123", "type", "function", "function", funCall);
toolCallPrompt = struct("role", "assistant", "content", "", "tool_calls", []);
toolCallPrompt.tool_calls = [toolCall,toolCall,toolCall];
msgs = addResponseMessage(msgs, toolCallPrompt);
testCase.verifyEqual(msgs.Messages{1}, toolCallPrompt);
end

function messageGetsRemoved(testCase)
Expand All @@ -105,7 +118,7 @@ function messageGetsRemoved(testCase)

msgs = addSystemMessage(msgs, "name", "content");
msgs = addUserMessage(msgs, "content");
msgs = addFunctionMessage(msgs, "name", "content");
msgs = addToolMessage(msgs, "123", "name", "content");
sizeMsgs = length(msgs.Messages);
% Message exists before removal
msgToBeRemoved = msgs.Messages{idx};
Expand All @@ -121,7 +134,7 @@ function removalIdxCantBeLargerThanNumElements(testCase)

msgs = addSystemMessage(msgs, "name", "content");
msgs = addUserMessage(msgs, "content");
msgs = addFunctionMessage(msgs, "name", "content");
msgs = addToolMessage(msgs, "123", "name", "content");
sizeMsgs = length(msgs.Messages);

testCase.verifyError(@()removeMessage(msgs, sizeMsgs+1), "llms:mustBeValidIndex");
Expand All @@ -144,7 +157,7 @@ function invalidInputsUserImagesPrompt(testCase, InvalidInputsUserImagesPrompt)

function invalidInputsFunctionPrompt(testCase, InvalidInputsFunctionPrompt)
msgs = openAIMessages;
testCase.verifyError(@()addFunctionMessage(msgs,InvalidInputsFunctionPrompt.Input{:}), InvalidInputsFunctionPrompt.Error);
testCase.verifyError(@()addToolMessage(msgs,InvalidInputsFunctionPrompt.Input{:}), InvalidInputsFunctionPrompt.Error);
end

function invalidInputsRemove(testCase, InvalidRemoveMessage)
Expand Down Expand Up @@ -231,27 +244,27 @@ function invalidInputsResponsePrompt(testCase, InvalidInputsResponseMessage)
function invalidFunctionPrompt = iGetInvalidFunctionPrompt
invalidFunctionPrompt = struct( ...
"NonStringInputName", ...
struct("Input", {{123, "content"}}, ...
struct("Input", {{"123", 123, "content"}}, ...
"Error", "MATLAB:validators:mustBeNonzeroLengthText"), ...
...
"NonStringInputContent", ...
struct("Input", {{"name", 123}}, ...
struct("Input", {{"123", "name", 123}}, ...
"Error", "MATLAB:validators:mustBeNonzeroLengthText"), ...
...
"EmptytName", ...
struct("Input", {{"", "content"}}, ...
struct("Input", {{"123", "", "content"}}, ...
"Error", "MATLAB:validators:mustBeNonzeroLengthText"), ...
...
"EmptytContent", ...
struct("Input", {{"name", ""}}, ...
struct("Input", {{"123", "name", ""}}, ...
"Error", "MATLAB:validators:mustBeNonzeroLengthText"), ...
...
"NonScalarInputName", ...
struct("Input", {{["name1" "name2"], "content"}}, ...
struct("Input", {{"123", ["name1" "name2"], "content"}}, ...
"Error", "MATLAB:validators:mustBeTextScalar"),...
...
"NonScalarInputContent", ...
struct("Input", {{"name", ["content1", "content2"]}}, ...
struct("Input", {{"123","name", ["content1", "content2"]}}, ...
"Error", "MATLAB:validators:mustBeTextScalar"));
end

Expand Down

0 comments on commit 3bf8430

Please sign in to comment.