Skip to content

Commit

Permalink
Merge pull request #6 from matlab-deep-learning/dev-parallelfunctionc…
Browse files Browse the repository at this point in the history
…all-jsonmode-vision-dalle

Updating the API to the newest version.
  • Loading branch information
debymf authored Jan 27, 2024
2 parents a4baf34 + 510002a commit 7e789c3
Show file tree
Hide file tree
Showing 19 changed files with 1,141 additions and 178 deletions.
42 changes: 29 additions & 13 deletions +llms/+internal/callOpenAIChatAPI.m
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
function [text, message, response] = callOpenAIChatAPI(messages, functions, nvp)
% This function is undocumented and will change in a future release

%callOpenAIChatAPI Calls the openAI chat completions API.
%
% MESSAGES and FUNCTIONS should be structs matching the json format
% required by the OpenAI Chat Completions API.
% Ref: https://platform.openai.com/docs/guides/gpt/chat-completions-api
%
% Currently, the supported NVP are, including the equivalent name in the API:
% - FunctionCall (function_call)
% - ToolChoice (tool_choice)
% - ModelName (model)
% - Temperature (temperature)
% - TopProbabilityMass (top_p)
Expand All @@ -17,6 +15,8 @@
% - MaxNumTokens (max_tokens)
% - PresencePenalty (presence_penalty)
% - FrequencyPenalty (frequence_penalty)
% - ResponseFormat (response_format)
% - Seed (seed)
% - ApiKey
% - TimeOut
% - StreamFun
Expand Down Expand Up @@ -50,12 +50,12 @@
% % Send a request
% [text, message] = llms.internal.callOpenAIChatAPI(messages, functions, ApiKey=apiKey)

% Copyright 2023 The MathWorks, Inc.
% Copyright 2023-2024 The MathWorks, Inc.

arguments
messages
functions
nvp.FunctionCall = []
nvp.ToolChoice = []
nvp.ModelName = "gpt-3.5-turbo"
nvp.Temperature = 1
nvp.TopProbabilityMass = 1
Expand All @@ -64,6 +64,8 @@
nvp.MaxNumTokens = inf
nvp.PresencePenalty = 0
nvp.FrequencyPenalty = 0
nvp.ResponseFormat = "text"
nvp.Seed = []
nvp.ApiKey = ""
nvp.TimeOut = 10
nvp.StreamFun = []
Expand All @@ -85,7 +87,7 @@
message = struct("role", "assistant", ...
"content", streamedText);
end
if isfield(message, "function_call")
if isfield(message, "tool_choice")
text = "";
else
text = string(message.content);
Expand All @@ -105,22 +107,36 @@

parameters.stream = ~isempty(nvp.StreamFun);

if ~isempty(functions)
parameters.functions = functions;
if ~isempty(functions) && ~strcmp(nvp.ModelName,'gpt-4-vision-preview')
parameters.tools = functions;
end

if ~isempty(nvp.ToolChoice) && ~strcmp(nvp.ModelName,'gpt-4-vision-preview')
parameters.tool_choice = nvp.ToolChoice;
end

if ismember(nvp.ModelName,["gpt-3.5-turbo-1106","gpt-4-1106-preview"])
if strcmp(nvp.ResponseFormat,"json")
parameters.response_format = struct('type','json_object');
end
end

if ~isempty(nvp.FunctionCall)
parameters.function_call = nvp.FunctionCall;
if ~isempty(nvp.Seed)
parameters.seed = nvp.Seed;
end

parameters.model = nvp.ModelName;

dict = mapNVPToParameters;

nvpOptions = keys(dict);
for i=1:length(nvpOptions)
if isfield(nvp, nvpOptions(i))
parameters.(dict(nvpOptions(i))) = nvp.(nvpOptions(i));
if strcmp(nvp.ModelName,'gpt-4-vision-preview')
nvpOptions(ismember(nvpOptions,["MaxNumTokens","StopSequences"])) = [];
end

for opt = nvpOptions.'
if isfield(nvp, opt)
parameters.(dict(opt)) = nvp.(opt);
end
end
end
Expand Down
21 changes: 12 additions & 9 deletions +llms/+utils/errorMessageCatalog.m
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
classdef errorMessageCatalog
% This class is undocumented and will change in a future release

%errorMessageCatalog Stores the error messages from this repository

% Copyright 2023 The MathWorks, Inc.
% Copyright 2023-2024 The MathWorks, Inc.

properties(Constant)
%CATALOG dictionary mapping error ids to error msgs
Catalog = buildErrorMessageCatalog;
end

methods(Static)
function msg = getMessage(messageId, slot)
% This function is undocumented and will change in a future release

%getMessage returns error message given a messageID and a SLOT.
% The value in SLOT should be ordered, where the n-th element
% will replace the value "{n}".
Expand Down Expand Up @@ -41,13 +38,19 @@
catalog("llms:parameterMustBeUnique") = "A parameter name equivalent to '{1}' already exists in Parameters. Redefining a parameter is not allowed.";
catalog("llms:mustBeAssistantCall") = "Input struct must contain field 'role' with value 'assistant', and field 'content'.";
catalog("llms:mustBeAssistantWithContent") = "Input struct must contain field 'content' containing text with one or more characters.";
catalog("llms:mustBeAssistantWithNameAndArguments") = "Field 'function_call' must be a struct with fields 'name' and 'arguments'.";
catalog("llms:mustBeAssistantWithIdAndFunction") = "Field 'tool_call' must be a struct with fields 'id' and 'function'.";
catalog("llms:mustBeAssistantWithNameAndArguments") = "Field 'function' must be a struct with fields 'name' and 'arguments'.";
catalog("llms:assistantMustHaveTextNameAndArguments") = "Fields 'name' and 'arguments' must be text with one or more characters.";
catalog("llms:mustBeValidIndex") = "Value is larger than the number of elements in Messages ({1}).";
catalog("llms:stopSequencesMustHaveMax4Elements") = "Number of elements must not be larger than 4.";
catalog("llms:keyMustBeSpecified") = "API key not found as environment variable OPENAI_API_KEY and not specified via ApiKey parameter.";
catalog("llms:mustHaveMessages") = "Value must contain at least one message in Messages.";
catalog("llms:mustSetFunctionsForCall") = "When no functions are defined, FunctionCall must not be specified.";
catalog("llms:mustSetFunctionsForCall") = "When no functions are defined, ToolChoice must not be specified.";
catalog("llms:mustBeMessagesOrTxt") = "Messages must be text with one or more characters or an openAIMessages objects.";
end

catalog("llms:invalidOptionAndValueForModel") = "'{1}' with value '{2}' is not supported for ModelName '{3}'";
catalog("llms:invalidOptionForModel") = "{1} is not supported for ModelName '{2}'";
catalog("llms:functionNotAvailableForModel") = "This function is not supported for ModelName '{1}'";
catalog("llms:promptLimitCharacter") = "Prompt must have a maximum length of {1} characters for ModelName '{2}'";
catalog("llms:pngExpected") = "Argument must be a PNG image.";
catalog("llms:warningJsonInstruction") = "When using JSON mode, you must also prompt the model to produce JSON yourself via a system or user message.";
end
28 changes: 28 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
*.fig binary
*.mat binary
*.mdl binary diff merge=mlAutoMerge
*.mdlp binary
*.mexa64 binary
*.mexw64 binary
*.mexmaci64 binary
*.mlapp binary
*.mldatx binary
*.mlproj binary
*.mlx binary
*.p binary
*.sfx binary
*.sldd binary
*.slreqx binary merge=mlAutoMerge
*.slmx binary merge=mlAutoMerge
*.sltx binary
*.slxc binary
*.slx binary merge=mlAutoMerge
*.slxp binary

## Other common binary file types
*.docx binary
*.exe binary
*.jpg binary
*.pdf binary
*.png binary
*.xlsx binary
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
*.env
*.asv
startup.m
Loading

0 comments on commit 7e789c3

Please sign in to comment.