Skip to content

Commit

Permalink
adding new models.
Browse files Browse the repository at this point in the history
  • Loading branch information
debymf committed Jan 27, 2024
1 parent 7e789c3 commit a038630
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 10 deletions.
22 changes: 18 additions & 4 deletions extractOpenAIEmbeddings.m
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,21 @@
%
% 'TimeOut' - Connection Timeout in seconds (default: 10 secs)
%
% 'Dimensions' - Number of dimensions the resulting output
% embeddings should have.
%
% [emb, response] = EXTRACTOPENAIEMBEDDINGS(...) also returns the full
% response from the OpenAI API call.
%
% Copyright 2023 The MathWorks, Inc.

arguments
text (1,:) {mustBeText}
nvp.ModelName (1,1) {mustBeMember(nvp.ModelName,"text-embedding-ada-002")} = "text-embedding-ada-002"
nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 10
nvp.ApiKey {llms.utils.mustBeNonzeroLengthTextScalar}
text (1,:) {mustBeText}
nvp.ModelName (1,1) {mustBeMember(nvp.ModelName,["text-embedding-ada-002", ...
"text-embedding-3-large", "text-embedding-3-small"])} = "text-embedding-ada-002"
nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 10
nvp.Dimensions (1,1) {mustBeInteger}
nvp.ApiKey {llms.utils.mustBeNonzeroLengthTextScalar}
end

END_POINT = "https://api.openai.com/v1/embeddings";
Expand All @@ -32,6 +37,15 @@

parameters = struct("input",text,"model",nvp.ModelName);

if isfield(nvp, "Dimensions")
if nvp.ModelName=="text-embedding-ada-002"
error("llms:invalidOptionForModel", ...
llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionForModel", "Dimensions", nvp.ModelName));
end
parameters.dimensions = nvp.Dimensions;
end


response = llms.internal.sendRequest(parameters,key, END_POINT, nvp.TimeOut);

if isfield(response.Body.Data, "data")
Expand Down
12 changes: 6 additions & 6 deletions openAIChat.m
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
nvp.ModelName (1,1) {mustBeMember(nvp.ModelName,["gpt-4", "gpt-4-0613", "gpt-4-32k", ...
"gpt-3.5-turbo", "gpt-3.5-turbo-16k",...
"gpt-4-1106-preview","gpt-3.5-turbo-1106", ...
"gpt-4-vision-preview"])} = "gpt-3.5-turbo"
"gpt-4-vision-preview", "gpt-4-turbo-preview"])} = "gpt-3.5-turbo"
nvp.Temperature {mustBeValidTemperature} = 1
nvp.TopProbabilityMass {mustBeValidTopP} = 1
nvp.StopSequences {mustBeValidStop} = {}
Expand All @@ -132,7 +132,7 @@
if isfield(nvp,"StreamFun")
this.StreamFun = nvp.StreamFun;
if strcmp(nvp.ModelName,'gpt-4-vision-preview')
error("llms:invalidOptionAndValueForModel", ...
error("llms:invalidOptionForModel", ...
llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionForModel", "StreamFun", nvp.ModelName));
end
else
Expand All @@ -147,7 +147,7 @@
this.Tools = nvp.Tools;
[this.FunctionsStruct, this.FunctionNames] = functionAsStruct(nvp.Tools);
if strcmp(nvp.ModelName,'gpt-4-vision-preview')
error("llms:invalidOptionAndValueForModel", ...
error("llms:invalidOptionForModel", ...
llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionForModel", "Tools", nvp.ModelName));
end
end
Expand All @@ -164,7 +164,7 @@
this.TopProbabilityMass = nvp.TopProbabilityMass;
this.StopSequences = nvp.StopSequences;
if ~isempty(nvp.StopSequences) && strcmp(nvp.ModelName,'gpt-4-vision-preview')
error("llms:invalidOptionAndValueForModel", ...
error("llms:invalidOptionForModel", ...
llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionForModel", "StopSequences", nvp.ModelName));
end

Expand Down Expand Up @@ -222,13 +222,13 @@
end

if nvp.MaxNumTokens ~= Inf && strcmp(this.ModelName,'gpt-4-vision-preview')
error("llms:invalidOptionAndValueForModel", ...
error("llms:invalidOptionForModel", ...
llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionForModel", "MaxNumTokens", this.ModelName));
end

toolChoice = convertToolChoice(this, nvp.ToolChoice);
if ~isempty(nvp.ToolChoice) && strcmp(this.ModelName,'gpt-4-vision-preview')
error("llms:invalidOptionAndValueForModel", ...
error("llms:invalidOptionForModel", ...
llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionForModel", "ToolChoice", this.ModelName));
end

Expand Down
16 changes: 16 additions & 0 deletions tests/textractOpenAIEmbeddings.m
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ function keyNotFound(testCase)
testCase.verifyError(@()extractOpenAIEmbeddings("bla"), "llms:keyMustBeSpecified");
end

function invalidCombinationOfModelAndDimension(testCase)
testCase.verifyError(@()extractOpenAIEmbeddings("bla", ...
Dimensions=10,...
ModelName="text-embedding-ada-002", ...
ApiKey="not-real"), ...
"llms:invalidOptionForModel")
end

function useAllNVP(testCase)
testCase.verifyWarningFree(@()extractOpenAIEmbeddings("bla", ModelName="text-embedding-ada-002", ...
ApiKey="this-is-not-a-real-key", TimeOut=10));
Expand Down Expand Up @@ -72,6 +80,14 @@ function testInvalidInputs(testCase, InvalidInput)
"Input",{{"bla", "ModelName", "gpt" }},...
"Error","MATLAB:validators:mustBeMember"),...
...
"InvalidDimensionType",struct( ...
"Input",{{"bla", "Dimensions", "123" }},...
"Error","MATLAB:validators:mustBeNumericOrLogical"),...
...
"InvalidDimensionSize",struct( ...
"Input",{{"bla", "Dimensions", [123, 123] }},...
"Error","MATLAB:validation:IncompatibleSize"),...
...
"InvalidApiKeyType",struct( ...
"Input",{{"bla", "ApiKey" 123 }},...
"Error","MATLAB:validators:mustBeNonzeroLengthText"),...
Expand Down

0 comments on commit a038630

Please sign in to comment.