diff --git a/extractOpenAIEmbeddings.m b/extractOpenAIEmbeddings.m index 4bec7bb..9660052 100644 --- a/extractOpenAIEmbeddings.m +++ b/extractOpenAIEmbeddings.m @@ -14,16 +14,21 @@ % % 'TimeOut' - Connection Timeout in seconds (default: 10 secs) % +% 'Dimensions' - Number of dimensions the resulting output +% embeddings should have. Must be a positive integer. +% % [emb, response] = EXTRACTOPENAIEMBEDDINGS(...) also returns the full % response from the OpenAI API call. % % Copyright 2023 The MathWorks, Inc. arguments - text (1,:) {mustBeText} - nvp.ModelName (1,1) {mustBeMember(nvp.ModelName,"text-embedding-ada-002")} = "text-embedding-ada-002" - nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 10 - nvp.ApiKey {llms.utils.mustBeNonzeroLengthTextScalar} + text (1,:) {mustBeText} + nvp.ModelName (1,1) {mustBeMember(nvp.ModelName,["text-embedding-ada-002", ... + "text-embedding-3-large", "text-embedding-3-small"])} = "text-embedding-ada-002" + nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 10 + nvp.Dimensions (1,1) {mustBeInteger,mustBePositive} + nvp.ApiKey {llms.utils.mustBeNonzeroLengthTextScalar} end END_POINT = "https://api.openai.com/v1/embeddings"; @@ -32,6 +37,15 @@ parameters = struct("input",text,"model",nvp.ModelName); +if isfield(nvp, "Dimensions") + if nvp.ModelName=="text-embedding-ada-002" + error("llms:invalidOptionForModel", ... + llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionForModel", "Dimensions", nvp.ModelName)); + end + parameters.dimensions = nvp.Dimensions; +end + + response = llms.internal.sendRequest(parameters,key, END_POINT, nvp.TimeOut); if isfield(response.Body.Data, "data") diff --git a/openAIChat.m b/openAIChat.m index ed91bc1..fc15fbb 100644 --- a/openAIChat.m +++ b/openAIChat.m @@ -117,7 +117,7 @@ nvp.ModelName (1,1) {mustBeMember(nvp.ModelName,["gpt-4", "gpt-4-0613", "gpt-4-32k", ... "gpt-3.5-turbo", "gpt-3.5-turbo-16k",... "gpt-4-1106-preview","gpt-3.5-turbo-1106", ... 
- "gpt-4-vision-preview"])} = "gpt-3.5-turbo" + "gpt-4-vision-preview", "gpt-4-turbo-preview"])} = "gpt-3.5-turbo" nvp.Temperature {mustBeValidTemperature} = 1 nvp.TopProbabilityMass {mustBeValidTopP} = 1 nvp.StopSequences {mustBeValidStop} = {} @@ -132,7 +132,7 @@ if isfield(nvp,"StreamFun") this.StreamFun = nvp.StreamFun; if strcmp(nvp.ModelName,'gpt-4-vision-preview') - error("llms:invalidOptionAndValueForModel", ... + error("llms:invalidOptionForModel", ... llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionForModel", "StreamFun", nvp.ModelName)); end else @@ -147,7 +147,7 @@ this.Tools = nvp.Tools; [this.FunctionsStruct, this.FunctionNames] = functionAsStruct(nvp.Tools); if strcmp(nvp.ModelName,'gpt-4-vision-preview') - error("llms:invalidOptionAndValueForModel", ... + error("llms:invalidOptionForModel", ... llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionForModel", "Tools", nvp.ModelName)); end end @@ -164,7 +164,7 @@ this.TopProbabilityMass = nvp.TopProbabilityMass; this.StopSequences = nvp.StopSequences; if ~isempty(nvp.StopSequences) && strcmp(nvp.ModelName,'gpt-4-vision-preview') - error("llms:invalidOptionAndValueForModel", ... + error("llms:invalidOptionForModel", ... llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionForModel", "StopSequences", nvp.ModelName)); end @@ -222,13 +222,13 @@ end if nvp.MaxNumTokens ~= Inf && strcmp(this.ModelName,'gpt-4-vision-preview') - error("llms:invalidOptionAndValueForModel", ... + error("llms:invalidOptionForModel", ... llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionForModel", "MaxNumTokens", this.ModelName)); end toolChoice = convertToolChoice(this, nvp.ToolChoice); if ~isempty(nvp.ToolChoice) && strcmp(this.ModelName,'gpt-4-vision-preview') - error("llms:invalidOptionAndValueForModel", ... + error("llms:invalidOptionForModel", ... 
llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionForModel", "ToolChoice", this.ModelName)); end diff --git a/tests/textractOpenAIEmbeddings.m b/tests/textractOpenAIEmbeddings.m index 4791e01..a58c1f6 100644 --- a/tests/textractOpenAIEmbeddings.m +++ b/tests/textractOpenAIEmbeddings.m @@ -31,6 +31,14 @@ function keyNotFound(testCase) testCase.verifyError(@()extractOpenAIEmbeddings("bla"), "llms:keyMustBeSpecified"); end + function invalidCombinationOfModelAndDimension(testCase) + testCase.verifyError(@()extractOpenAIEmbeddings("bla", ... + Dimensions=10,... + ModelName="text-embedding-ada-002", ... + ApiKey="not-real"), ... + "llms:invalidOptionForModel") + end + function useAllNVP(testCase) testCase.verifyWarningFree(@()extractOpenAIEmbeddings("bla", ModelName="text-embedding-ada-002", ... ApiKey="this-is-not-a-real-key", TimeOut=10)); @@ -72,6 +80,18 @@ function testInvalidInputs(testCase, InvalidInput) "Input",{{"bla", "ModelName", "gpt" }},... "Error","MATLAB:validators:mustBeMember"),... ... + "InvalidDimensionType",struct( ... + "Input",{{"bla", "Dimensions", "123" }},... + "Error","MATLAB:validators:mustBeNumericOrLogical"),... + ... + "InvalidDimensionSize",struct( ... + "Input",{{"bla", "Dimensions", [123, 123] }},... + "Error","MATLAB:validation:IncompatibleSize"),... + ... + "InvalidDimensionValue",struct( ... + "Input",{{"bla", "Dimensions", -1 }},... + "Error","MATLAB:validators:mustBePositive"),... + ... "InvalidApiKeyType",struct( ... "Input",{{"bla", "ApiKey" 123 }},... "Error","MATLAB:validators:mustBeNonzeroLengthText"),...