diff --git a/+llms/+internal/callAzureChatAPI.m b/+llms/+internal/callAzureChatAPI.m
new file mode 100644
index 0000000..7791001
--- /dev/null
+++ b/+llms/+internal/callAzureChatAPI.m
@@ -0,0 +1,138 @@
+function [text, message, response] = callAzureChatAPI(resourceName, deploymentID, messages, functions, nvp)
+%callAzureChatAPI Calls the Azure OpenAI chat completions API.
+%
+% MESSAGES and FUNCTIONS should be structs matching the json format
+% required by the OpenAI Chat Completions API.
+% Ref: https://platform.openai.com/docs/guides/gpt/chat-completions-api
+%
+% The following NVP are currently supported, with the equivalent API
+% parameter name in parentheses:
+% - ToolChoice (tool_choice)
+% - APIVersion (api-version)
+% - Temperature (temperature)
+% - TopProbabilityMass (top_p)
+% - NumCompletions (n)
+% - StopSequences (stop)
+% - MaxNumTokens (max_tokens)
+% - PresencePenalty (presence_penalty)
+% - FrequencyPenalty (frequency_penalty)
+% - ResponseFormat (response_format)
+% - Seed (seed)
+% - ApiKey
+% - TimeOut
+% - StreamFun
+% More details on the parameters: https://platform.openai.com/docs/api-reference/chat/create
+%
+% Example
+%
+% % Create messages struct
+% messages = {struct("role", "system",...
+%     "content", "You are a helpful assistant");
+%     struct("role", "user", ...
+%     "content", "What is the edit distance between hi and hello?")};
+%
+% % Create functions struct
+% functions = {struct("name", "editDistance", ...
+%     "description", "Find edit distance between two strings or documents.", ...
+%     "parameters", struct( ...
+%     "type", "object", ...
+%     "properties", struct(...
+%         "str1", struct(...
+%             "description", "Source string.", ...
+%             "type", "string"),...
+%         "str2", struct(...
+%             "description", "Target string.", ...
+%             "type", "string")),...
+%     "required", ["str1", "str2"]))};
+%
+% % Define your API key
+% apiKey = "your-api-key-here"
+%
+% % Send a request
+% [text, message] = llms.internal.callAzureChatAPI("your-resource-name", ...
+%     "your-deployment-id", messages, functions, ApiKey=apiKey)
+
+% Copyright 2023-2024 The MathWorks, Inc.
+
+arguments
+    resourceName
+    deploymentID
+    messages
+    functions
+    nvp.ToolChoice = []
+    nvp.APIVersion = "2023-05-15"
+    nvp.Temperature = 1
+    nvp.TopProbabilityMass = 1
+    nvp.NumCompletions = 1
+    nvp.StopSequences = []
+    nvp.MaxNumTokens = inf
+    nvp.PresencePenalty = 0
+    nvp.FrequencyPenalty = 0
+    nvp.ResponseFormat = "text"
+    nvp.Seed = []
+    nvp.ApiKey = ""
+    nvp.TimeOut = 10
+    nvp.StreamFun = []
+end
+
+END_POINT = "https://" + resourceName + ".openai.azure.com/openai/deployments/" + deploymentID + "/chat/completions?api-version=" + nvp.APIVersion;
+
+parameters = buildParametersCall(messages, functions, nvp);
+
+[response, streamedText] = llms.internal.sendRequest(parameters,nvp.ApiKey, END_POINT, nvp.TimeOut, nvp.StreamFun);
+
+% If the call errors, "choices" will not be part of response.Body.Data;
+% instead we get response.Body.Data.error
+if response.StatusCode=="OK"
+    % Outputs the first generation
+    if isempty(nvp.StreamFun)
+        message = response.Body.Data.choices(1).message;
+    else
+        message = struct("role", "assistant", ...
+            "content", streamedText);
+    end
+    if isfield(message, "tool_calls")
+        text = "";
+    else
+        text = string(message.content);
+    end
+else
+    text = "";
+    message = struct();
+end
+end
+
+function parameters = buildParametersCall(messages, functions, nvp)
+% Builds a struct in the format that is expected by the API, combining
+% MESSAGES, FUNCTIONS and parameters in NVP.
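+%
+% Illustrative only: with default options and no tools, jsonencode of the
+% struct assembled below produces JSON roughly of the form
+%   {"messages":[...], "stream":false, "temperature":1, "top_p":1, "n":1, ...}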
+
+parameters = struct();
+parameters.messages = messages;
+
+parameters.stream = ~isempty(nvp.StreamFun);
+
+% Only include tools and tool_choice when tools were provided;
+% the API rejects empty values for these fields.
+if ~isempty(functions)
+    parameters.tools = functions;
+end
+
+if ~isempty(nvp.ToolChoice)
+    parameters.tool_choice = nvp.ToolChoice;
+end
+
+if ~isempty(nvp.Seed)
+    parameters.seed = nvp.Seed;
+end
+
+dict = mapNVPToParameters;
+
+nvpOptions = keys(dict);
+for opt = nvpOptions.'
+    if isfield(nvp, opt)
+        parameters.(dict(opt)) = nvp.(opt);
+    end
+end
+end
+
+function dict = mapNVPToParameters()
+dict = dictionary();
+dict("Temperature") = "temperature";
+dict("TopProbabilityMass") = "top_p";
+dict("NumCompletions") = "n";
+dict("StopSequences") = "stop";
+dict("MaxNumTokens") = "max_tokens";
+dict("PresencePenalty") = "presence_penalty";
+dict("FrequencyPenalty") = "frequency_penalty";
+end
\ No newline at end of file
diff --git a/+llms/+internal/textGenerator.m b/+llms/+internal/textGenerator.m
new file mode 100644
index 0000000..cb65a8e
--- /dev/null
+++ b/+llms/+internal/textGenerator.m
@@ -0,0 +1,90 @@
+classdef (Abstract) textGenerator
+
+    properties
+        %TEMPERATURE Temperature of generation.
+        Temperature
+
+        %TOPPROBABILITYMASS Top probability mass to consider for generation.
+        TopProbabilityMass
+
+        %STOPSEQUENCES Sequences to stop the generation of tokens.
+        StopSequences
+
+        %PRESENCEPENALTY Penalty for using a token in the response that has already been used.
+        PresencePenalty
+
+        %FREQUENCYPENALTY Penalty for using a token that is frequent in the training data.
+        FrequencyPenalty
+    end
+
+    properties (SetAccess=protected)
+        %TIMEOUT Connection timeout in seconds (default 10 secs)
+        TimeOut
+
+        %FUNCTIONNAMES Names of the functions that the model can request calls
+        FunctionNames
+
+        %SYSTEMPROMPT System prompt.
+        SystemPrompt = []
+
+        %RESPONSEFORMAT Response format, "text" or "json"
+        ResponseFormat
+    end
+
+    properties (Access=protected)
+        Tools
+        FunctionsStruct
+        ApiKey
+        StreamFun
+    end
+
+
+    methods
+        function this = set.Temperature(this, temperature)
+            arguments
+                this
+                temperature
+            end
+            llms.utils.mustBeValidTemperature(temperature);
+            this.Temperature = temperature;
+        end
+
+        function this = set.TopProbabilityMass(this,topP)
+            arguments
+                this
+                topP
+            end
+            llms.utils.mustBeValidTopP(topP);
+            this.TopProbabilityMass = topP;
+        end
+
+        function this = set.StopSequences(this,stop)
+            arguments
+                this
+                stop
+            end
+            llms.utils.mustBeValidStop(stop);
+            this.StopSequences = stop;
+        end
+
+        function this = set.PresencePenalty(this,penalty)
+            arguments
+                this
+                penalty
+            end
+            llms.utils.mustBeValidPenalty(penalty)
+            this.PresencePenalty = penalty;
+        end
+
+        function this = set.FrequencyPenalty(this,penalty)
+            arguments
+                this
+                penalty
+            end
+            llms.utils.mustBeValidPenalty(penalty)
+            this.FrequencyPenalty = penalty;
+        end
+
+    end
+
+end
\ No newline at end of file
diff --git a/+llms/+utils/errorMessageCatalog.m b/+llms/+utils/errorMessageCatalog.m
index 3791319..c9b21d1 100644
--- a/+llms/+utils/errorMessageCatalog.m
+++ b/+llms/+utils/errorMessageCatalog.m
@@ -53,4 +53,7 @@
 catalog("llms:promptLimitCharacter") = "Prompt must have a maximum length of {1} characters for ModelName '{2}'";
 catalog("llms:pngExpected") = "Argument must be a PNG image.";
 catalog("llms:warningJsonInstruction") = "When using JSON mode, you must also prompt the model to produce JSON yourself via a system or user message.";
+catalog("llms:invalidOptionsForOpenAIBackEnd") = "The parameters Resource Name, Deployment ID and API Version are not compatible with OpenAI.";
+catalog("llms:invalidOptionsForAzureBackEnd") = "The parameter Model Name is not compatible with Azure.";
+
 end
\ No newline at end of file
diff --git a/+llms/+utils/mustBeValidPenalty.m b/+llms/+utils/mustBeValidPenalty.m
new file mode 100644
index 0000000..be83e16
--- /dev/null
+++ b/+llms/+utils/mustBeValidPenalty.m
@@ -0,0 +1,3 @@
+function mustBeValidPenalty(value)
+    validateattributes(value, {'numeric'}, {'real', 'scalar', 'nonsparse', '<=', 2, '>=', -2})
+end
\ No newline at end of file
diff --git a/+llms/+utils/mustBeValidStop.m b/+llms/+utils/mustBeValidStop.m
new file mode 100644
index 0000000..187dfae
--- /dev/null
+++ b/+llms/+utils/mustBeValidStop.m
@@ -0,0 +1,10 @@
+function mustBeValidStop(value)
+    if ~isempty(value)
+        mustBeVector(value);
+        mustBeNonzeroLengthText(value);
+        % This restriction is set by the OpenAI API
+        if numel(value)>4
+            error("llms:stopSequencesMustHaveMax4Elements", llms.utils.errorMessageCatalog.getMessage("llms:stopSequencesMustHaveMax4Elements"));
+        end
+    end
+end
\ No newline at end of file
diff --git a/+llms/+utils/mustBeValidTemperature.m b/+llms/+utils/mustBeValidTemperature.m
new file mode 100644
index 0000000..1ab6604
--- /dev/null
+++ b/+llms/+utils/mustBeValidTemperature.m
@@ -0,0 +1,3 @@
+function mustBeValidTemperature(value)
+    validateattributes(value, {'numeric'}, {'real', 'scalar', 'nonnegative', 'nonsparse', '<=', 2})
+end
\ No newline at end of file
diff --git a/+llms/+utils/mustBeValidTopP.m b/+llms/+utils/mustBeValidTopP.m
new file mode 100644
index 0000000..ffa4aba
--- /dev/null
+++ b/+llms/+utils/mustBeValidTopP.m
@@ -0,0 +1,3 @@
+function mustBeValidTopP(value)
+    validateattributes(value, {'numeric'}, {'real', 'scalar', 'nonnegative', 'nonsparse', '<=', 1})
+end
\ No newline at end of file
diff --git a/README.md b/README.md
index b1fdcf9..689660e 100644
--- a/README.md
+++ b/README.md
@@ -288,6 +288,35 @@
 messages = addUserMessageWithImages(messages,"What is in the image?",image_path)
 % Should output the description of the image
 ```
+## Establishing a connection to Chat Completions API using Azure®
+
+If you would like to connect MATLAB to the Chat Completions API via Azure® instead of directly with OpenAI, you have to create an `azureChat` object. Before doing so, you need your Azure API key and the name of your Azure OpenAI Resource.
+
+To create the chat assistant, specify your Azure OpenAI Resource and the deployment name of the LLM you want to use:
+```matlab
+chat = azureChat(YOUR_RESOURCE_NAME, YOUR_DEPLOYMENT_NAME, "You are a helpful AI assistant");
+```
+
+The `azureChat` object accepts the same additional options as the `openAIChat` object, with one exception: the `ModelName` option is not available, because the LLM is already determined by the deployment name you specify when creating the chat assistant.
+
+The `azureChat` object also offers an additional option, `APIVersion`, which lets you set the API version to use for the operation.
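+
+The `Tools` option works in the same way as with `openAIChat`; note that tool calling requires `APIVersion` to be `"2023-12-01-preview"`. For example, the `openAIFunction` below is a placeholder; define your own:
+```matlab
+% Hypothetical tool; replace with your own openAIFunction definition
+f = openAIFunction("sayHello", "Returns a friendly greeting");
+
+% Tool calling requires a preview API version
+chat = azureChat(YOUR_RESOURCE_NAME, YOUR_DEPLOYMENT_NAME, Tools=f, APIVersion="2023-12-01-preview");
+```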
+
+After establishing your connection with Azure, you can continue using the `generate` function and the other objects in the same way as if you had established a connection directly with OpenAI:
+```matlab
+% Initialize the Azure Chat object, passing a system prompt and specifying the API version
+chat = azureChat(YOUR_RESOURCE_NAME, YOUR_DEPLOYMENT_NAME, "You are a helpful AI assistant", APIVersion="2023-12-01-preview");
+
+% Create an openAIMessages object to start the conversation history
+history = openAIMessages;
+
+% Ask your question and store it in the history; generate a response
+% using the generate method and add the response to the history
+history = addUserMessage(history,"What is an eigenvalue?");
+[txt, response] = generate(chat, history)
+history = addResponseMessage(history, response);
+```
+
 ### Obtaining embeddings
 
 You can extract embeddings from your text with OpenAI using the function `extractOpenAIEmbeddings` as follows:
diff --git a/azureChat.m b/azureChat.m
new file mode 100644
index 0000000..0ede901
--- /dev/null
+++ b/azureChat.m
@@ -0,0 +1,258 @@
+classdef(Sealed) azureChat < llms.internal.textGenerator
+%azureChat Chat completion API from Azure.
+%
+% CHAT = azureChat(resourceName, deploymentID) creates an azureChat object with the
+% resource name and deployment ID path parameters required by Azure to establish the connection.
+%
+% CHAT = azureChat(resourceName, deploymentID, systemPrompt) creates an azureChat
+% object with the specified system prompt.
+%
+% CHAT = azureChat(resourceName, deploymentID, systemPrompt, Name=Value) specifies
+% additional options using one or more name-value arguments:
+%
+% Tools               - A list of tools the model can call.
+%                       This parameter requires API version 2023-12-01-preview.
+%
+% APIVersion          - The API version to use for this operation.
+%                       Default value is "2023-05-15".
+%
+% Temperature         - Temperature value for controlling the randomness
+%                       of the output. Default value is 1.
+%
+% TopProbabilityMass  - Top probability mass value for controlling the
+%                       diversity of the output. Default value is 1.
+%
+% StopSequences       - Vector of strings that when encountered, will
+%                       stop the generation of tokens. Default
+%                       value is empty.
+%
+% ResponseFormat      - The format of response the model returns.
+%                       "text" (default) | "json"
+%
+% ApiKey              - The API key for accessing the Azure OpenAI Chat API.
+%
+% PresencePenalty     - Penalty value for using a token in the response
+%                       that has already been used. Default value is 0.
+%
+% FrequencyPenalty    - Penalty value for using a token that is frequent
+%                       in the training data. Default value is 0.
+%
+% StreamFun           - Function to call back when streaming the
+%                       result
+%
+% TimeOut             - Connection Timeout in seconds (default: 10 secs)
+%
+%
+%
+% azureChat Functions:
+%   azureChat           - Chat completion API from Azure.
+%   generate            - Generate a response using the azureChat instance.
+%
+% azureChat Properties:
+%   Temperature         - Temperature of generation.
+%
+%   TopProbabilityMass  - Top probability mass to consider for generation.
+%
+%   StopSequences       - Sequences to stop the generation of tokens.
+%
+%   PresencePenalty     - Penalty for using a token in the
+%                         response that has already been used.
+%
+%   FrequencyPenalty    - Penalty for using a token that is
+%                         frequent in the training data.
+%
+%   SystemPrompt        - System prompt.
+%
+%   FunctionNames       - Names of the functions that the model can
+%                         request calls.
+%
+%   ResponseFormat      - Specifies the response format, text or json
+%
+%   TimeOut             - Connection Timeout in seconds (default: 10 secs)
+%
+
+% Copyright 2023-2024 The MathWorks, Inc.
+
+    properties(SetAccess=private)
+        ResourceName
+        DeploymentID
+        APIVersion
+    end
+
+
+    methods
+        function this = azureChat(resourceName, deploymentID, systemPrompt, nvp)
+            arguments
+                resourceName {mustBeTextScalar}
+                deploymentID {mustBeTextScalar}
+                systemPrompt {llms.utils.mustBeTextOrEmpty} = []
+                nvp.Tools (1,:) {mustBeA(nvp.Tools, "openAIFunction")} = openAIFunction.empty
+                nvp.APIVersion (1,1) {mustBeMember(nvp.APIVersion,["2023-03-15-preview", "2023-05-15", "2023-06-01-preview", ...
+                    "2023-07-01-preview", "2023-08-01-preview",...
+                    "2023-12-01-preview"])} = "2023-05-15"
+                nvp.Temperature {llms.utils.mustBeValidTemperature} = 1
+                nvp.TopProbabilityMass {llms.utils.mustBeValidTopP} = 1
+                nvp.StopSequences {llms.utils.mustBeValidStop} = {}
+                nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = "text"
+                nvp.ApiKey {mustBeNonzeroLengthTextScalar}
+                nvp.PresencePenalty {llms.utils.mustBeValidPenalty} = 0
+                nvp.FrequencyPenalty {llms.utils.mustBeValidPenalty} = 0
+                nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 10
+                nvp.StreamFun (1,1) {mustBeA(nvp.StreamFun,'function_handle')}
+            end
+
+            if isfield(nvp,"StreamFun")
+                this.StreamFun = nvp.StreamFun;
+            else
+                this.StreamFun = [];
+            end
+
+            if isempty(nvp.Tools)
+                this.Tools = [];
+                this.FunctionsStruct = [];
+                this.FunctionNames = [];
+            else
+                this.Tools = nvp.Tools;
+                [this.FunctionsStruct, this.FunctionNames] = functionAsStruct(nvp.Tools);
+            end
+
+            if ~isempty(systemPrompt)
+                systemPrompt = string(systemPrompt);
+                if ~(strlength(systemPrompt)==0)
+                    this.SystemPrompt = {struct("role", "system", "content", systemPrompt)};
+                end
+            end
+
+            this.ResourceName = resourceName;
+            this.DeploymentID = deploymentID;
+            this.APIVersion = nvp.APIVersion;
+            this.ResponseFormat = nvp.ResponseFormat;
+            this.Temperature = nvp.Temperature;
+            this.TopProbabilityMass = nvp.TopProbabilityMass;
+            this.StopSequences = nvp.StopSequences;
+            this.PresencePenalty = nvp.PresencePenalty;
+            this.FrequencyPenalty = nvp.FrequencyPenalty;
+            this.ApiKey = llms.internal.getApiKeyFromNvpOrEnv(nvp);
+            this.TimeOut = nvp.TimeOut;
+        end
+
+        function [text, message, response] = generate(this, messages, nvp)
+            %generate Generate a response using the azureChat instance.
+            %
+            % [TEXT, MESSAGE, RESPONSE] = generate(CHAT, MESSAGES) generates a response
+            % with the specified MESSAGES.
+            %
+            % [TEXT, MESSAGE, RESPONSE] = generate(__, Name=Value) specifies additional options
+            % using one or more name-value arguments:
+            %
+            % NumCompletions   - Number of completions to generate.
+            %                    Default value is 1.
+            %
+            % MaxNumTokens     - Maximum number of tokens in the generated response.
+            %                    Default value is inf.
+            %
+            % ToolChoice       - Function to execute. 'none', 'auto',
+            %                    or specify the function to call.
+            %
+            % Seed             - An integer value to use to obtain
+            %                    reproducible responses
+            %
+            % Currently, GPT-4 Turbo with vision does not support the message.name
+            % parameter, functions/tools, response_format parameter, stop
+            % sequences, and max_tokens
+
+            arguments
+                this (1,1) azureChat
+                messages (1,1) {mustBeValidMsgs}
+                nvp.NumCompletions (1,1) {mustBePositive, mustBeInteger} = 1
+                nvp.MaxNumTokens (1,1) {mustBePositive} = inf
+                nvp.ToolChoice {mustBeValidFunctionCall(this, nvp.ToolChoice)} = []
+                nvp.Seed {mustBeIntegerOrEmpty(nvp.Seed)} = []
+            end
+
+            if isstring(messages) && isscalar(messages)
+                messagesStruct = {struct("role", "user", "content", messages)};
+            else
+                messagesStruct = messages.Messages;
+            end
+
+            if ~isempty(this.SystemPrompt)
+                messagesStruct = horzcat(this.SystemPrompt, messagesStruct);
+            end
+
+            toolChoice = convertToolChoice(this, nvp.ToolChoice);
+            [text, message, response] = llms.internal.callAzureChatAPI(this.ResourceName, ...
+                this.DeploymentID, messagesStruct, this.FunctionsStruct, ...
+                ToolChoice=toolChoice, APIVersion = this.APIVersion, Temperature=this.Temperature, ...
+                TopProbabilityMass=this.TopProbabilityMass, NumCompletions=nvp.NumCompletions,...
+                StopSequences=this.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...
+                PresencePenalty=this.PresencePenalty, FrequencyPenalty=this.FrequencyPenalty, ...
+                ResponseFormat=this.ResponseFormat,Seed=nvp.Seed, ...
+                ApiKey=this.ApiKey,TimeOut=this.TimeOut, StreamFun=this.StreamFun);
+        end
+    end
+
+    methods(Hidden)
+        function mustBeValidFunctionCall(this, functionCall)
+            if ~isempty(functionCall)
+                mustBeTextScalar(functionCall);
+                if isempty(this.FunctionNames)
+                    error("llms:mustSetFunctionsForCall", llms.utils.errorMessageCatalog.getMessage("llms:mustSetFunctionsForCall"));
+                end
+                mustBeMember(functionCall, ["none","auto", this.FunctionNames]);
+            end
+        end
+
+        function toolChoice = convertToolChoice(this, toolChoice)
+            % if toolChoice is empty
+            if isempty(toolChoice)
+                % if Tools is not empty, the default is 'auto'.
+                if ~isempty(this.Tools)
+                    toolChoice = "auto";
+                end
+            elseif toolChoice ~= "auto" && toolChoice ~= "none"
+                % a specific function name must be wrapped in the format
+                % {"type": "function", "function": {"name": "my_function"}}
+                toolChoice = struct("type","function","function",struct("name",toolChoice));
+            end
+
+        end
+    end
+end
+
+function mustBeNonzeroLengthTextScalar(content)
+mustBeNonzeroLengthText(content)
+mustBeTextScalar(content)
+end
+
+function [functionsStruct, functionNames] = functionAsStruct(functions)
+numFunctions = numel(functions);
+functionsStruct = cell(1, numFunctions);
+functionNames = strings(1, numFunctions);
+
+for i = 1:numFunctions
+    functionsStruct{i} = struct('type','function', ...
+        'function',encodeStruct(functions(i)));
+    functionNames(i) = functions(i).FunctionName;
+end
+end
+
+function mustBeValidMsgs(value)
+if isa(value, "openAIMessages")
+    if numel(value.Messages) == 0
+        error("llms:mustHaveMessages", llms.utils.errorMessageCatalog.getMessage("llms:mustHaveMessages"));
+    end
+else
+    try
+        llms.utils.mustBeNonzeroLengthTextScalar(value);
+    catch ME
+        error("llms:mustBeMessagesOrTxt", llms.utils.errorMessageCatalog.getMessage("llms:mustBeMessagesOrTxt"));
+    end
+end
+end
+
+function mustBeIntegerOrEmpty(value)
+    if ~isempty(value)
+        mustBeInteger(value)
+    end
+end
\ No newline at end of file
diff --git a/openAIChat.m b/openAIChat.m
index fc15fbb..389ceab 100644
--- a/openAIChat.m
+++ b/openAIChat.m
@@ -1,4 +1,4 @@
-classdef(Sealed) openAIChat
+classdef(Sealed) openAIChat < llms.internal.textGenerator
 %openAIChat Chat completion API from OpenAI.
 %
 % CHAT = openAIChat(systemPrompt) creates an openAIChat object with the
@@ -68,46 +68,12 @@
 
 % Copyright 2023-2024 The MathWorks, Inc.
 
-    properties
-        %TEMPERATURE Temperature of generation.
-        Temperature
-
-        %TOPPROBABILITYMASS Top probability mass to consider for generation.
-        TopProbabilityMass
-
-        %STOPSEQUENCES Sequences to stop the generation of tokens.
-        StopSequences
-
-        %PRESENCEPENALTY Penalty for using a token in the response that has already been used.
-        PresencePenalty
-
-        %FREQUENCYPENALTY Penalty for using a token that is frequent in the training data.
-        FrequencyPenalty
-    end
 
     properties(SetAccess=private)
-        %TIMEOUT Connection timeout in seconds (default 10 secs)
-        TimeOut
-
-        %FUNCTIONNAMES Names of the functions that the model can request calls
-        FunctionNames
-
         %MODELNAME Model name.
         ModelName
-
-        %SYSTEMPROMPT System prompt.
-        SystemPrompt = []
-
-        %RESPONSEFORMAT Response format, "text" or "json"
-        ResponseFormat
     end
 
-    properties(Access=private)
-        Tools
-        FunctionsStruct
-        ApiKey
-        StreamFun
-    end
 
     methods
         function this = openAIChat(systemPrompt, nvp)
@@ -117,14 +83,14 @@
             nvp.ModelName (1,1) {mustBeMember(nvp.ModelName,["gpt-4", "gpt-4-0613", "gpt-4-32k", ...
                 "gpt-3.5-turbo", "gpt-3.5-turbo-16k",...
                 "gpt-4-1106-preview","gpt-3.5-turbo-1106", ...
- "gpt-4-vision-preview", "gpt-4-turbo-preview"])} = "gpt-3.5-turbo" - nvp.Temperature {mustBeValidTemperature} = 1 - nvp.TopProbabilityMass {mustBeValidTopP} = 1 - nvp.StopSequences {mustBeValidStop} = {} + "gpt-4-vision-preview", "gpt-4-turbo-preview"])} = "gpt-3.5-turbo" + nvp.Temperature {llms.utils.mustBeValidTemperature} = 1 + nvp.TopProbabilityMass {llms.utils.mustBeValidTopP} = 1 + nvp.StopSequences {llms.utils.mustBeValidStop} = {} nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = "text" nvp.ApiKey {mustBeNonzeroLengthTextScalar} - nvp.PresencePenalty {mustBeValidPenalty} = 0 - nvp.FrequencyPenalty {mustBeValidPenalty} = 0 + nvp.PresencePenalty {llms.utils.mustBeValidPenalty} = 0 + nvp.FrequencyPenalty {llms.utils.mustBeValidPenalty} = 0 nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 10 nvp.StreamFun (1,1) {mustBeA(nvp.StreamFun,'function_handle')} end @@ -251,51 +217,6 @@ ApiKey=this.ApiKey,TimeOut=this.TimeOut, StreamFun=this.StreamFun); end - function this = set.Temperature(this, temperature) - arguments - this openAIChat - temperature - end - mustBeValidTemperature(temperature); - - this.Temperature = temperature; - end - - function this = set.TopProbabilityMass(this,topP) - arguments - this openAIChat - topP - end - mustBeValidTopP(topP); - this.TopProbabilityMass = topP; - end - - function this = set.StopSequences(this,stop) - arguments - this openAIChat - stop - end - mustBeValidStop(stop); - this.StopSequences = stop; - end - - function this = set.PresencePenalty(this,penalty) - arguments - this openAIChat - penalty - end - mustBeValidPenalty(penalty) - this.PresencePenalty = penalty; - end - - function this = set.FrequencyPenalty(this,penalty) - arguments - this openAIChat - penalty - end - mustBeValidPenalty(penalty) - this.FrequencyPenalty = penalty; - end end methods(Hidden) @@ -357,29 +278,6 @@ function mustBeValidMsgs(value) end end -function mustBeValidPenalty(value) -validateattributes(value, {'numeric'}, {'real', 'scalar', 'nonsparse', '<=', 2, '>=', -2}) -end - -function mustBeValidTopP(value) -validateattributes(value, {'numeric'}, {'real', 'scalar', 'nonnegative', 'nonsparse', '<=', 1}) -end - -function mustBeValidTemperature(value) -validateattributes(value, {'numeric'}, {'real', 'scalar', 'nonnegative', 'nonsparse', '<=', 2}) -end - -function mustBeValidStop(value) -if ~isempty(value) - mustBeVector(value); - mustBeNonzeroLengthText(value); - % This restriction is set by the OpenAI API - if numel(value)>4 - error("llms:stopSequencesMustHaveMax4Elements", llms.utils.errorMessageCatalog.getMessage("llms:stopSequencesMustHaveMax4Elements")); - end -end -end - function mustBeIntegerOrEmpty(value) if ~isempty(value) mustBeInteger(value) diff --git a/tests/tazureChat.m b/tests/tazureChat.m new file mode 100644 index 0000000..c483baf --- /dev/null +++ b/tests/tazureChat.m @@ -0,0 +1,375 @@ +classdef tazureChat < matlab.unittest.TestCase +% Tests for azureChat + +% Copyright 2024 The MathWorks, Inc. 
+
+    methods (TestClassSetup)
+        function saveEnvVar(testCase)
+            % Ensures key is not in environment variable for tests
+            openAIEnvVar = "OPENAI_API_KEY";
+            if isenv(openAIEnvVar)
+                key = getenv(openAIEnvVar);
+                unsetenv(openAIEnvVar);
+                testCase.addTeardown(@(x) setenv(openAIEnvVar, x), key);
+            end
+        end
+    end
+
+    properties(TestParameter)
+        InvalidConstructorInput = iGetInvalidConstructorInput;
+        InvalidGenerateInput = iGetInvalidGenerateInput;
+        InvalidValuesSetters = iGetInvalidValuesSetters;
+    end
+
+    methods(Test)
+        % Test methods
+        function keyNotFound(testCase)
+            testCase.verifyError(@()azureChat("My_resource", "Deployment"), "llms:keyMustBeSpecified");
+        end
+
+        function constructChatWithAllNVP(testCase)
+            resourceName = "resource";
+            deploymentID = "hello";
+            functions = openAIFunction("funName");
+            temperature = 0;
+            topP = 1;
+            stop = ["[END]", "."];
+            apiKey = "this-is-not-a-real-key";
+            presenceP = -2;
+            frequencyP = 2;
+            systemPrompt = "This is a system prompt";
+            timeout = 3;
+            chat = azureChat(resourceName, deploymentID, systemPrompt, Tools=functions, ...
+                Temperature=temperature, TopProbabilityMass=topP, StopSequences=stop, ApiKey=apiKey,...
+                FrequencyPenalty=frequencyP, PresencePenalty=presenceP, TimeOut=timeout);
+            testCase.verifyEqual(chat.Temperature, temperature);
+            testCase.verifyEqual(chat.TopProbabilityMass, topP);
+            testCase.verifyEqual(chat.StopSequences, stop);
+            testCase.verifyEqual(chat.FrequencyPenalty, frequencyP);
+            testCase.verifyEqual(chat.PresencePenalty, presenceP);
+        end
+
+        function verySmallTimeOutErrors(testCase)
+            chat = azureChat("My_resource", "Deployment", TimeOut=0.0001, ApiKey="false-key");
+            testCase.verifyError(@()generate(chat, "hi"), "MATLAB:webservices:Timeout")
+        end
+
+        function errorsWhenPassingToolChoiceWithEmptyTools(testCase)
+            chat = azureChat("My_resource", "Deployment", ApiKey="this-is-not-a-real-key");
+            testCase.verifyError(@()generate(chat,"input", ToolChoice="bla"), "llms:mustSetFunctionsForCall");
+        end
+
+        function invalidInputsConstructor(testCase, InvalidConstructorInput)
+            testCase.verifyError(@()azureChat("My_resource", "Deployment", InvalidConstructorInput.Input{:}), InvalidConstructorInput.Error);
+        end
+
+        function invalidInputsGenerate(testCase, InvalidGenerateInput)
+            f = openAIFunction("validfunction");
+            chat = azureChat("My_resource", "Deployment", Tools=f, ApiKey="this-is-not-a-real-key");
+            testCase.verifyError(@()generate(chat,InvalidGenerateInput.Input{:}), InvalidGenerateInput.Error);
+        end
+
+        function invalidSetters(testCase, InvalidValuesSetters)
+            chat = azureChat("My_resource", "Deployment", ApiKey="this-is-not-a-real-key");
+            function assignValueToProperty(property, value)
+                chat.(property) = value;
+            end
+
+            testCase.verifyError(@()assignValueToProperty(InvalidValuesSetters.Property,InvalidValuesSetters.Value), InvalidValuesSetters.Error);
+        end
+    end
+end
+
+function invalidValuesSetters = iGetInvalidValuesSetters
+
+invalidValuesSetters = struct( ...
+    "InvalidTemperatureType", struct( ...
+        "Property", "Temperature", ...
+        "Value", "2", ...
+        "Error", "MATLAB:invalidType"), ...
+    ...
+    "InvalidTemperatureSize", struct( ...
+        "Property", "Temperature", ...
+        "Value", [1 1 1], ...
+        "Error", "MATLAB:expectedScalar"), ...
+    ...
+    "TemperatureTooLarge", struct( ...
+        "Property", "Temperature", ...
+        "Value", 20, ...
+        "Error", "MATLAB:notLessEqual"), ...
+    ...
+    "TemperatureTooSmall", struct( ...
+        "Property", "Temperature", ...
+        "Value", -20, ...
+        "Error", "MATLAB:expectedNonnegative"), ...
+    ...
+ "InvalidTopProbabilityMassType", struct( ... + "Property", "TopProbabilityMass", ... + "Value", "2", ... + "Error", "MATLAB:invalidType"), ... + ... + "InvalidTopProbabilityMassSize", struct( ... + "Property", "TopProbabilityMass", ... + "Value", [1 1 1], ... + "Error", "MATLAB:expectedScalar"), ... + ... + "TopProbabilityMassTooLarge", struct( ... + "Property", "TopProbabilityMass", ... + "Value", 20, ... + "Error", "MATLAB:notLessEqual"), ... + ... + "TopProbabilityMassTooSmall", struct( ... + "Property", "TopProbabilityMass", ... + "Value", -20, ... + "Error", "MATLAB:expectedNonnegative"), ... + ... + "WrongTypeStopSequences", struct( ... + "Property", "StopSequences", ... + "Value", 123, ... + "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... + ... + "WrongSizeStopNonVector", struct( ... + "Property", "StopSequences", ... + "Value", repmat("stop", 4), ... + "Error", "MATLAB:validators:mustBeVector"), ... + ... + "EmptyStopSequences", struct( ... + "Property", "StopSequences", ... + "Value", "", ... + "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... + ... + "WrongSizeStopSequences", struct( ... + "Property", "StopSequences", ... + "Value", ["1" "2" "3" "4" "5"], ... + "Error", "llms:stopSequencesMustHaveMax4Elements"), ... + ... + "InvalidPresencePenalty", struct( ... + "Property", "PresencePenalty", ... + "Value", "2", ... + "Error", "MATLAB:invalidType"), ... + ... + "InvalidPresencePenaltySize", struct( ... + "Property", "PresencePenalty", ... + "Value", [1 1 1], ... + "Error", "MATLAB:expectedScalar"), ... + ... + "PresencePenaltyTooLarge", struct( ... + "Property", "PresencePenalty", ... + "Value", 20, ... + "Error", "MATLAB:notLessEqual"), ... + ... + "PresencePenaltyTooSmall", struct( ... + "Property", "PresencePenalty", ... + "Value", -20, ... + "Error", "MATLAB:notGreaterEqual"), ... + ... + "InvalidFrequencyPenalty", struct( ... + "Property", "FrequencyPenalty", ... + "Value", "2", ... + "Error", "MATLAB:invalidType"), ... + ... + "InvalidFrequencyPenaltySize", struct( ... + "Property", "FrequencyPenalty", ... + "Value", [1 1 1], ... + "Error", "MATLAB:expectedScalar"), ... + ... + "FrequencyPenaltyTooLarge", struct( ... + "Property", "FrequencyPenalty", ... + "Value", 20, ... + "Error", "MATLAB:notLessEqual"), ... + ... + "FrequencyPenaltyTooSmall", struct( ... + "Property", "FrequencyPenalty", ... + "Value", -20, ... + "Error", "MATLAB:notGreaterEqual")); +end + +function invalidConstructorInput = iGetInvalidConstructorInput +validFunction = openAIFunction("funName"); +invalidConstructorInput = struct( ... + "InvalidResponseFormatValue", struct( ... + "Input",{{"ResponseFormat", "foo" }},... + "Error", "MATLAB:validators:mustBeMember"), ... + ... + "InvalidResponseFormatSize", struct( ... + "Input",{{"ResponseFormat", ["text" "text"] }},... + "Error", "MATLAB:validation:IncompatibleSize"), ... + ... + "InvalidStreamFunType", struct( ... + "Input",{{"StreamFun", "2" }},... + "Error", "MATLAB:validators:mustBeA"), ... + ... + "InvalidStreamFunSize", struct( ... + "Input",{{"StreamFun", [1 1 1] }},... + "Error", "MATLAB:validation:IncompatibleSize"), ... + ... + "InvalidTimeOutType", struct( ... + "Input",{{"TimeOut", "2" }},... + "Error", "MATLAB:validators:mustBeReal"), ... + ... + "InvalidTimeOutSize", struct( ... + "Input",{{"TimeOut", [1 1 1] }},... + "Error", "MATLAB:validation:IncompatibleSize"), ... + ... + "WrongTypeSystemPrompt",struct( ... + "Input",{{ 123 }},... + "Error","MATLAB:validators:mustBeTextScalar"),... + ... 
+ "WrongSizeSystemPrompt",struct( ... + "Input",{{ ["test"; "test"] }},... + "Error","MATLAB:validators:mustBeTextScalar"),... + ... + "InvalidToolsType",struct( ... + "Input",{{"Tools", "a" }},... + "Error","MATLAB:validators:mustBeA"),... + ... + "InvalidToolsSize",struct( ... + "Input",{{"Tools", repmat(validFunction, 2, 2) }},... + "Error","MATLAB:validation:IncompatibleSize"),... + ... + "InvalidAPIVersionType",struct( ... + "Input",{{"APIVersion", 0}},... + "Error","MATLAB:validators:mustBeMember"),... + ... + "InvalidAPIVersionSize",struct( ... + "Input",{{"APIVersion", ["2023-05-15", "2023-05-15"]}},... + "Error","MATLAB:validation:IncompatibleSize"),... + ... + "InvalidAPIVersionOption",struct( ... + "Input",{{ "APIVersion", "gpt" }},... + "Error","MATLAB:validators:mustBeMember"),... + ... + "InvalidTemperatureType",struct( ... + "Input",{{ "Temperature" "2" }},... + "Error","MATLAB:invalidType"),... + ... + "InvalidTemperatureSize",struct( ... + "Input",{{ "Temperature" [1 1 1] }},... + "Error","MATLAB:expectedScalar"),... + ... + "TemperatureTooLarge",struct( ... + "Input",{{ "Temperature" 20 }},... + "Error","MATLAB:notLessEqual"),... + ... + "TemperatureTooSmall",struct( ... + "Input",{{ "Temperature" -20 }},... + "Error","MATLAB:expectedNonnegative"),... + ... + "InvalidTopProbabilityMassType",struct( ... + "Input",{{ "TopProbabilityMass" "2" }},... + "Error","MATLAB:invalidType"),... + ... + "InvalidTopProbabilityMassSize",struct( ... + "Input",{{ "TopProbabilityMass" [1 1 1] }},... + "Error","MATLAB:expectedScalar"),... + ... + "TopProbabilityMassTooLarge",struct( ... + "Input",{{ "TopProbabilityMass" 20 }},... + "Error","MATLAB:notLessEqual"),... + ... + "TopProbabilityMassTooSmall",struct( ... + "Input",{{ "TopProbabilityMass" -20 }},... + "Error","MATLAB:expectedNonnegative"),... + ... + "WrongTypeStopSequences",struct( ... + "Input",{{ "StopSequences" 123}},... + "Error","MATLAB:validators:mustBeNonzeroLengthText"),... + ... + "WrongSizeStopNonVector",struct( ... + "Input",{{ "StopSequences" repmat("stop", 4) }},... + "Error","MATLAB:validators:mustBeVector"),... + ... + "EmptyStopSequences",struct( ... + "Input",{{ "StopSequences" ""}},... + "Error","MATLAB:validators:mustBeNonzeroLengthText"),... + ... + "WrongSizeStopSequences",struct( ... + "Input",{{ "StopSequences" ["1" "2" "3" "4" "5"]}},... + "Error","llms:stopSequencesMustHaveMax4Elements"),... + ... + "InvalidPresencePenalty",struct( ... + "Input",{{ "PresencePenalty" "2" }},... + "Error","MATLAB:invalidType"),... + ... + "InvalidPresencePenaltySize",struct( ... + "Input",{{ "PresencePenalty" [1 1 1] }},... + "Error","MATLAB:expectedScalar"),... + ... + "PresencePenaltyTooLarge",struct( ... + "Input",{{ "PresencePenalty" 20 }},... + "Error","MATLAB:notLessEqual"),... + ... + "PresencePenaltyTooSmall",struct( ... + "Input",{{ "PresencePenalty" -20 }},... + "Error","MATLAB:notGreaterEqual"),... + ... + "InvalidFrequencyPenalty",struct( ... + "Input",{{ "FrequencyPenalty" "2" }},... + "Error","MATLAB:invalidType"),... + ... + "InvalidFrequencyPenaltySize",struct( ... + "Input",{{ "FrequencyPenalty" [1 1 1] }},... + "Error","MATLAB:expectedScalar"),... + ... + "FrequencyPenaltyTooLarge",struct( ... + "Input",{{ "FrequencyPenalty" 20 }},... + "Error","MATLAB:notLessEqual"),... + ... + "FrequencyPenaltyTooSmall",struct( ... + "Input",{{ "FrequencyPenalty" -20 }},... + "Error","MATLAB:notGreaterEqual"),... + ... + "InvalidApiKeyType",struct( ... + "Input",{{ "ApiKey" 123 }},... 
+ "Error","MATLAB:validators:mustBeNonzeroLengthText"),... + ... + "InvalidApiKeySize",struct( ... + "Input",{{ "ApiKey" ["abc" "abc"] }},... + "Error","MATLAB:validators:mustBeTextScalar")); +end + +function invalidGenerateInput = iGetInvalidGenerateInput +emptyMessages = openAIMessages; +validMessages = addUserMessage(emptyMessages,"Who invented the telephone?"); + +invalidGenerateInput = struct( ... + "EmptyInput",struct( ... + "Input",{{ [] }},... + "Error","MATLAB:validation:IncompatibleSize"),... + ... + "InvalidInputType",struct( ... + "Input",{{ 123 }},... + "Error","llms:mustBeMessagesOrTxt"),... + ... + "EmptyMessages",struct( ... + "Input",{{ emptyMessages }},... + "Error","llms:mustHaveMessages"),... + ... + "InvalidMaxNumTokensType",struct( ... + "Input",{{ validMessages "MaxNumTokens" "2" }},... + "Error","MATLAB:validators:mustBeNumericOrLogical"),... + ... + "InvalidMaxNumTokensValue",struct( ... + "Input",{{ validMessages "MaxNumTokens" 0 }},... + "Error","MATLAB:validators:mustBePositive"),... + ... + "InvalidNumCompletionsType",struct( ... + "Input",{{ validMessages "NumCompletions" "2" }},... + "Error","MATLAB:validators:mustBeNumericOrLogical"),... + ... + "InvalidNumCompletionsValue",struct( ... + "Input",{{ validMessages "NumCompletions" 0 }},... + "Error","MATLAB:validators:mustBePositive"), ... + ... + "InvalidToolChoiceValue",struct( ... + "Input",{{ validMessages "ToolChoice" "functionDoesNotExist" }},... + "Error","MATLAB:validators:mustBeMember"),... + ... + "InvalidToolChoiceType",struct( ... + "Input",{{ validMessages "ToolChoice" 0 }},... + "Error","MATLAB:validators:mustBeTextScalar"),... + ... + "InvalidToolChoiceSize",struct( ... + "Input",{{ validMessages "ToolChoice" ["validfunction", "validfunction"] }},... + "Error","MATLAB:validators:mustBeTextScalar")); +end \ No newline at end of file