diff --git a/+llms/+internal/callAzureChatAPI.m b/+llms/+internal/callAzureChatAPI.m index 907a5d1..532f9e8 100644 --- a/+llms/+internal/callAzureChatAPI.m +++ b/+llms/+internal/callAzureChatAPI.m @@ -73,8 +73,20 @@ if isempty(nvp.StreamFun) message = response.Body.Data.choices(1).message; else - message = struct("role", "assistant", ... - "content", streamedText); + pat = '{"' + wildcardPattern + '":'; + if contains(streamedText,pat) + s = jsondecode(streamedText); + if contains(s.function.arguments,pat) + prompt = jsondecode(s.function.arguments); + s.function.arguments = prompt; + end + message = struct("role", "assistant", ... + "content",[], ... + "tool_calls",jsondecode(streamedText)); + else + message = struct("role", "assistant", ... + "content", streamedText); + end end if isfield(message, "tool_choice") text = ""; diff --git a/+llms/+internal/callOllamaChatAPI.m b/+llms/+internal/callOllamaChatAPI.m index ce81780..7654677 100644 --- a/+llms/+internal/callOllamaChatAPI.m +++ b/+llms/+internal/callOllamaChatAPI.m @@ -84,6 +84,15 @@ parameters.stream = ~isempty(nvp.StreamFun); options = struct; + +if strcmp(nvp.ResponseFormat,"json") + parameters.format = struct('type','json_object'); +elseif isstruct(nvp.ResponseFormat) + parameters.format = llms.internal.jsonSchemaFromPrototype(nvp.ResponseFormat); +elseif startsWith(string(nvp.ResponseFormat), asManyOfPattern(whitespacePattern)+"{") + parameters.format = llms.internal.verbatimJSON(nvp.ResponseFormat); +end + if ~isempty(nvp.Seed) options.seed = nvp.Seed; end diff --git a/+llms/+internal/hasTools.m b/+llms/+internal/hasTools.m index d7a346c..cba95c7 100644 --- a/+llms/+internal/hasTools.m +++ b/+llms/+internal/hasTools.m @@ -12,4 +12,31 @@ Tools FunctionsStruct end + + methods(Hidden) + function mustBeValidFunctionCall(this, functionCall) + if ~isempty(functionCall) + mustBeTextScalar(functionCall); + if isempty(this.FunctionNames) + error("llms:mustSetFunctionsForCall", llms.utils.errorMessageCatalog.getMessage("llms:mustSetFunctionsForCall")); + end + mustBeMember(functionCall, ["none","auto", this.FunctionNames]); + end + end + + function toolChoice = convertToolChoice(this, toolChoice) + % if toolChoice is empty + if isempty(toolChoice) + % if Tools is not empty, the default is 'auto'. 
+ if ~isempty(this.Tools) + toolChoice = "auto"; + end + elseif ~ismember(toolChoice,["auto","none"]) + % if toolChoice is not empty, then it must be "auto", "none" or in the format + % {"type": "function", "function": {"name": "my_function"}} + toolChoice = struct("type","function","function",struct("name",toolChoice)); + end + + end + end end diff --git a/+llms/+internal/jsonSchemaFromPrototype.m b/+llms/+internal/jsonSchemaFromPrototype.m index f67943b..984849e 100644 --- a/+llms/+internal/jsonSchemaFromPrototype.m +++ b/+llms/+internal/jsonSchemaFromPrototype.m @@ -27,7 +27,7 @@ schema = struct("type","string"); elseif isinteger(prototype) schema = struct("type","integer"); - elseif isnumeric(prototype) + elseif isnumeric(prototype) && ~isa(prototype,'dlarray') schema = struct("type","number"); elseif islogical(prototype) schema = struct("type","boolean"); diff --git a/+llms/+internal/useSameFieldTypes.m b/+llms/+internal/useSameFieldTypes.m index 60856f5..97f11a4 100644 --- a/+llms/+internal/useSameFieldTypes.m +++ b/+llms/+internal/useSameFieldTypes.m @@ -21,7 +21,7 @@ case "struct" prototype = prototype(1); if isscalar(data) - if isequal(fieldnames(data),fieldnames(prototype)) + if isequal(sort(fieldnames(data)),sort(fieldnames(prototype))) for field_c = fieldnames(data).' field = field_c{1}; data.(field) = alignTypes(data.(field),prototype.(field)); diff --git a/+llms/+openai/validateResponseFormat.m b/+llms/+openai/validateResponseFormat.m index bf97168..ced401a 100644 --- a/+llms/+openai/validateResponseFormat.m +++ b/+llms/+openai/validateResponseFormat.m @@ -18,9 +18,6 @@ function validateResponseFormat(format,model,messages) error("llms:warningJsonInstruction", ... llms.utils.errorMessageCatalog.getMessage("llms:warningJsonInstruction")) end - else - warning("llms:warningJsonInstruction", ... - llms.utils.errorMessageCatalog.getMessage("llms:warningJsonInstruction")) end elseif requestsStructuredOutput(format) if ~startsWith(model,"gpt-4o") diff --git a/+llms/+utils/errorMessageCatalog.m b/+llms/+utils/errorMessageCatalog.m index 71235c7..967254e 100644 --- a/+llms/+utils/errorMessageCatalog.m +++ b/+llms/+utils/errorMessageCatalog.m @@ -28,6 +28,15 @@ end end end + + function s = createCatalog() + %createCatalog will run the initialization code and return the catalog + % This is only meant to get more correct test coverage reports: + % The test coverage reports do not include the properties initialization + % for Catalog from above, so we have a test seam here to re-run it + % within the framework, where it is reported. + s = buildErrorMessageCatalog; + end end end @@ -66,4 +75,5 @@ catalog("llms:stream:responseStreamer:InvalidInput") = "Input does not have the expected json format, got ""{1}""."; catalog("llms:unsupportedDatatypeInPrototype") = "Invalid data type ''{1}'' in prototype. Prototype must be a struct, composed of numerical, string, logical, categorical, or struct."; catalog("llms:incorrectResponseFormat") = "Invalid response format. Response format must be ""text"", ""json"", a struct, or a string with a JSON Schema definition."; +catalog("llms:OllamaStructuredOutputNeeds05") = "Structured output is not supported for Ollama version {1}. 
Use version 0.5.0 or newer."; end diff --git a/+llms/+utils/requestsStructuredOutput.m b/+llms/+utils/requestsStructuredOutput.m new file mode 100644 index 0000000..19c2362 --- /dev/null +++ b/+llms/+utils/requestsStructuredOutput.m @@ -0,0 +1,9 @@ +function tf = requestsStructuredOutput(format) +% This function is undocumented and will change in a future release + +% Simple function to check if requested format triggers structured output + +% Copyright 2024 The MathWorks, Inc. +tf = isstruct(format) || startsWith(format,asManyOfPattern(whitespacePattern)+"{"); +end + diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index cf5aa55..be3bff1 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,4 +1,4 @@ # Code owners, to get auto-filled reviewer lists # To start with, we just assume everyone in the core team is included on all reviews -* @adulai @ccreutzi @debymf @MiriamScharnke @vpapanasta \ No newline at end of file +* @adulai @ccreutzi @debymf @MiriamScharnke @vpapanasta @emanuzzi diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6be9f48..fcc3ccd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -39,7 +39,6 @@ jobs: cache: true - name: Run tests and generate artifacts env: - OPENAI_KEY: ${{ secrets.OPENAI_KEY }} OPENAI_API_KEY: ${{ secrets.OPENAI_KEY }} AZURE_OPENAI_DEPLOYMENT: ${{ secrets.AZURE_DEPLOYMENT }} AZURE_OPENAI_ENDPOINT: ${{ secrets.AZURE_ENDPOINT }} @@ -47,11 +46,4 @@ jobs: SECOND_OLLAMA_ENDPOINT: 127.0.0.1:11435 uses: matlab-actions/run-tests@v2 with: - test-results-junit: test-results/results.xml - code-coverage-cobertura: code-coverage/coverage.xml - source-folder: . - - name: Upload coverage reports to Codecov - uses: codecov/codecov-action@v4 - with: - token: ${{ secrets.CODECOV_TOKEN }} - slug: matlab-deep-learning/llms-with-matlab + source-folder: . 
\ No newline at end of file diff --git a/azureChat.m b/azureChat.m index 51c82c8..66b15a4 100644 --- a/azureChat.m +++ b/azureChat.m @@ -103,9 +103,9 @@ function this = azureChat(systemPrompt, nvp) arguments systemPrompt {llms.utils.mustBeTextOrEmpty} = [] - nvp.Endpoint (1,1) string {mustBeNonzeroLengthTextScalar} - nvp.DeploymentID (1,1) string {mustBeNonzeroLengthTextScalar} - nvp.APIKey {mustBeNonzeroLengthTextScalar} + nvp.Endpoint (1,1) string {llms.utils.mustBeNonzeroLengthTextScalar} + nvp.DeploymentID (1,1) string {llms.utils.mustBeNonzeroLengthTextScalar} + nvp.APIKey {llms.utils.mustBeNonzeroLengthTextScalar} nvp.Tools (1,:) {mustBeA(nvp.Tools, "openAIFunction")} = openAIFunction.empty nvp.APIVersion (1,1) string {mustBeAPIVersion} = "2024-06-01" nvp.Temperature {llms.utils.mustBeValidTemperature} = 1 @@ -222,7 +222,7 @@ nvp.TopP {llms.utils.mustBeValidProbability} = this.TopP nvp.StopSequences {llms.utils.mustBeValidStop} = this.StopSequences nvp.ResponseFormat {llms.utils.mustBeResponseFormat} = this.ResponseFormat - nvp.APIKey {mustBeNonzeroLengthTextScalar} = this.APIKey + nvp.APIKey {llms.utils.mustBeNonzeroLengthTextScalar} = this.APIKey nvp.PresencePenalty {llms.utils.mustBeValidPenalty} = this.PresencePenalty nvp.FrequencyPenalty {llms.utils.mustBeValidPenalty} = this.FrequencyPenalty nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = this.TimeOut @@ -289,31 +289,6 @@ end methods(Hidden) - function mustBeValidFunctionCall(this, functionCall) - if ~isempty(functionCall) - mustBeTextScalar(functionCall); - if isempty(this.FunctionNames) - error("llms:mustSetFunctionsForCall", llms.utils.errorMessageCatalog.getMessage("llms:mustSetFunctionsForCall")); - end - mustBeMember(functionCall, ["none","auto", this.FunctionNames]); - end - end - - function toolChoice = convertToolChoice(this, toolChoice) - % if toolChoice is empty - if isempty(toolChoice) - % if Tools is not empty, the default is 'auto'. - if ~isempty(this.Tools) - toolChoice = "auto"; - end - elseif toolChoice ~= "auto" - % if toolChoice is not empty, then it must be in the format - % {"type": "function", "function": {"name": "my_function"}} - toolChoice = struct("type","function","function",struct("name",toolChoice)); - end - - end - function messageStruct = encodeImages(~, messageStruct) for k=1:numel(messageStruct) if isfield(messageStruct{k},"images") @@ -350,11 +325,6 @@ function mustBeValidFunctionCall(this, functionCall) end end -function mustBeNonzeroLengthTextScalar(content) -mustBeNonzeroLengthText(content) -mustBeTextScalar(content) -end - function [functionsStruct, functionNames] = functionAsStruct(functions) numFunctions = numel(functions); functionsStruct = cell(1, numFunctions); diff --git a/codecov.yml b/codecov.yml deleted file mode 100644 index 570b6df..0000000 --- a/codecov.yml +++ /dev/null @@ -1,3 +0,0 @@ -ignore: - - "tests" - - "**/errorMessageCatalog.m" diff --git a/doc/functions/ollamaChat.md b/doc/functions/ollamaChat.md index cb7ba51..6d05724 100644 --- a/doc/functions/ollamaChat.md +++ b/doc/functions/ollamaChat.md @@ -139,23 +139,40 @@ If the server does not respond within the timeout, then the function throws an e ### `ResponseFormat` — Response format -`"text"` (default) | `"json"` +`"text"` (default) | `"json"` | string scalar | structure array After construction, this property is read\-only. -Format of generated output. +Format of the `generatedOutput` output argument of the `generate` function. You can request unformatted output, JSON mode, or structured output. 
-If you set the response format to `"text"`, then the generated output is a string. 
+#### Unformatted Output 
 
-If you set the response format to `"json"`, then the generated output is a string containing JSON encoded data. 
+If you set the response format to `"text"`, then the generated output is an unformatted string. 
+
+
+#### JSON Mode 
+
+
+If you set the response format to `"json"`, then the generated output is a formatted string containing JSON encoded data. 
 
 
 To configure the format of the generated JSON file, describe the format using natural language and provide it to the model either in the system prompt or as a user message. The prompt or message describing the format must contain the word `"json"` or `"JSON"`. 
 
+#### Structured Output 
+
+
+This option is only supported for Ollama version 0.5.0 and later. 
+
+
+To ensure that the model follows the required format, use structured output. To do this, set `ResponseFormat` to: 
+
+- A string scalar containing a valid JSON Schema. 
+- A structure array containing an example that adheres to the required format, for example: `ResponseFormat=struct("Name","Rudolph","NoseColor",[255 0 0])` 
+
 # Other Properties 
 
 ### `SystemPrompt` — System prompt 
diff --git a/examples/AnalyzeSentimentinTextUsingChatGPTinJSONMode.md b/examples/AnalyzeSentimentinTextUsingChatGPTinJSONMode.md 
deleted file mode 100644 
index bef2356..0000000 
--- a/examples/AnalyzeSentimentinTextUsingChatGPTinJSONMode.md 
+++ /dev/null 
@@ -1,88 +0,0 @@ 
- 
-# Analyze Sentiment in Text Using ChatGPT™ in JSON Mode 
- 
-To run the code shown on this page, open the MLX file in MATLAB®: [mlx-scripts/AnalyzeSentimentinTextUsingChatGPTinJSONMode.mlx](mlx-scripts/AnalyzeSentimentinTextUsingChatGPTinJSONMode.mlx) 
- 
-This example shows how to use ChatGPT for sentiment analysis and output the results in JSON format. 
- 
- 
-To run this example, you need a valid API key from a paid OpenAI™ API account. 
- 
-```matlab 
-loadenv(".env") 
-addpath('../..') 
-``` 
- 
-Define some text to analyze the sentiment. 
- 
-```matlab 
-inputText = ["I can't stand homework."; 
-    "This sucks. I'm bored."; 
-    "I can't wait for Halloween!!!"; 
-    "I am neigher for or against the idea."; 
-    "My cat is adorable ❤️❤️"; 
-    "I hate chocolate"]; 
-``` 
- 
-Define the expected output JSON Schema. 
- 
-```matlab 
-jsonSchema = '{"sentiment": "string (positive, negative, neutral)","confidence_score": "number (0-1)"}'; 
-``` 
- 
-Define the system prompt, combining your instructions and the JSON Schema. In order for the model to output JSON, you need to specify that in the prompt, for example, adding *"designed to output to JSON"* to the prompt. 
- 
-```matlab 
-systemPrompt = "You are an AI designed to output to JSON. You analyze the sentiment of the provided text and " + ... 
-    "Determine whether the sentiment is positive, negative, or neutral and provide a confidence score using " + ... 
-    "the schema: " + jsonSchema; 
-prompt = "Analyze the sentiment of the provided text. " + ... 
-    "Determine whether the sentiment is positive, negative," + ... 
-    " or neutral and provide a confidence score"; 
-``` 
- 
-Create a chat object with `ModelName gpt-3.5-turbo` and specify `ResponseFormat` as `"json".` 
- 
-```matlab 
-model = "gpt-3.5-turbo"; 
-chat = openAIChat(systemPrompt, ModelName=model, ResponseFormat="json"); 
-``` 
- 
-```matlabTextOutput 
-Warning: When using JSON mode, you must also prompt the model to produce JSON yourself via a system or user message. 
-``` 
- 
-Concatenate the prompt and input text and generate an answer with the model. 
- -```matlab -scores = cell(1,numel(inputText)); -for i = 1:numel(inputText) -``` - -Generate a response from the message. - -```matlab - [json, message, response] = generate(chat,prompt + newline + newline + inputText(i)); - scores{i} = jsondecode(json); -end -``` - -Extract the data from the JSON ouput. - -```matlab -T = struct2table([scores{:}]); -T.text = inputText; -T = movevars(T,"text","Before","sentiment") -``` -| |text|sentiment|confidence_score| -|:--:|:--:|:--:|:--:| -|1|"I can't stand homework."|'negative'|0.9500| -|2|"This sucks. I'm bored."|'negative'|0.9000| -|3|"I can't wait for Halloween!!!"|'positive'|0.9500| -|4|"I am neigher for or against the idea."|'neutral'|1| -|5|"My cat is adorable ❤️❤️"|'positive'|0.9500| -|6|"I hate chocolate"|'negative'|0.9000| - - -*Copyright 2024 The MathWorks, Inc.* - diff --git a/examples/AnalyzeSentimentinTextUsingChatGPTwithStructuredOutput.md b/examples/AnalyzeSentimentinTextUsingChatGPTwithStructuredOutput.md new file mode 100644 index 0000000..31148fd --- /dev/null +++ b/examples/AnalyzeSentimentinTextUsingChatGPTwithStructuredOutput.md @@ -0,0 +1,94 @@ + +# Analyze Sentiment in Text Using ChatGPT™ and Structured Output + +To run the code shown on this page, open the MLX file in MATLAB®: [mlx-scripts/AnalyzeSentimentinTextUsingChatGPTwithStructuredOutput.mlx](mlx-scripts/AnalyzeSentimentinTextUsingChatGPTwithStructuredOutput.mlx) + +This example shows how to use ChatGPT for sentiment analysis and output the results in a desired format. + + +To run this example, you need a valid API key from a paid OpenAI™ API account. + +```matlab +loadenv(".env") +addpath('../..') +``` + +Define some text to analyze the sentiment. + +```matlab +inputText = ["I can't stand homework."; + "This sucks. I'm bored."; + "I can't wait for Halloween!!!"; + "I am neither for nor against the idea."; + "My cat is adorable ❤️❤️"; + "I hate chocolate"; + "More work. Great."; + "More work. Great!"]; +``` + +Define the system prompt. + +```matlab +systemPrompt = "You are an AI designed to analyze the sentiment of the provided text and " + ... + "Determine whether the sentiment is positive, negative, or neutral " + ... + "and provide a confidence score between 0 and 1."; +prompt = "Analyze the sentiment of the provided text."; +``` + +Define the expected output format by providing an example – when supplied with a struct as the `ResponseFormat`, `generate` will return a struct with the same field names and data types. Use a [categorical](https://www.mathworks.com/help/matlab/categorical-arrays.html) to restrict the values that can be returned to the list `["positive","negative","neutral"]`. + +```matlab +prototype = struct(... + "sentiment", categorical("positive",["positive","negative","neutral"]),... + "confidence", 0.2) +``` + +```matlabTextOutput +prototype = struct with fields: + sentiment: positive + confidence: 0.2000 + +``` + +Create a chat object and set `ResponseFormat` to `prototype`. + +```matlab +chat = openAIChat(systemPrompt, ResponseFormat=prototype); +``` + +Concatenate the prompt and input text and generate an answer with the model. + +```matlab +scores = []; +for i = 1:numel(inputText) +``` + +Generate a response from the message. + +```matlab + thisResponse = generate(chat,prompt + newline + newline + inputText(i)); + scores = [scores; thisResponse]; %#ok +end +``` + +Extract the content from the output structure array `scores`. 
+ +```matlab +T = struct2table(scores); +T.text = inputText; +T = movevars(T,"text","Before","sentiment") +``` +| |text|sentiment|confidence| +|:--:|:--:|:--:|:--:| +|1|"I can't stand homework."|negative|0.9500| +|2|"This sucks. I'm bored."|negative|0.9500| +|3|"I can't wait for Halloween!!!"|positive|0.9500| +|4|"I am neither for nor against the idea."|neutral|0.9500| +|5|"My cat is adorable ❤️❤️"|positive|0.9500| +|6|"I hate chocolate"|negative|0.9500| +|7|"More work. Great."|negative|0.8500| +|8|"More work. Great!"|positive|0.9000| + + +*Copyright 2024 The MathWorks, Inc.* + diff --git a/examples/mlx-scripts/AnalyzeSentimentinTextUsingChatGPTinJSONMode.mlx b/examples/mlx-scripts/AnalyzeSentimentinTextUsingChatGPTinJSONMode.mlx deleted file mode 100644 index c4581f3..0000000 Binary files a/examples/mlx-scripts/AnalyzeSentimentinTextUsingChatGPTinJSONMode.mlx and /dev/null differ diff --git a/examples/mlx-scripts/AnalyzeSentimentinTextUsingChatGPTwithStructuredOutput.mlx b/examples/mlx-scripts/AnalyzeSentimentinTextUsingChatGPTwithStructuredOutput.mlx new file mode 100644 index 0000000..a093836 Binary files /dev/null and b/examples/mlx-scripts/AnalyzeSentimentinTextUsingChatGPTwithStructuredOutput.mlx differ diff --git a/messageHistory.m b/messageHistory.m index adac787..3700ae4 100644 --- a/messageHistory.m +++ b/messageHistory.m @@ -41,8 +41,8 @@ arguments this (1,1) messageHistory - name {mustBeNonzeroLengthTextScalar} - content {mustBeNonzeroLengthTextScalar} + name {llms.utils.mustBeNonzeroLengthTextScalar} + content {llms.utils.mustBeNonzeroLengthTextScalar} end newMessage = struct("role", "system", "name", string(name), "content", string(content)); @@ -64,7 +64,7 @@ arguments this (1,1) messageHistory - content {mustBeNonzeroLengthTextScalar} + content {llms.utils.mustBeNonzeroLengthTextScalar} end newMessage = struct("role", "user", "content", string(content)); @@ -106,7 +106,7 @@ arguments this (1,1) messageHistory - content {mustBeNonzeroLengthTextScalar} + content {llms.utils.mustBeNonzeroLengthTextScalar} images (1,:) {mustBeNonzeroLengthText} nvp.Detail string {mustBeMember(nvp.Detail,["low","high","auto"])} = "auto" end @@ -133,9 +133,9 @@ arguments this (1,1) messageHistory - id {mustBeNonzeroLengthTextScalar} - name {mustBeNonzeroLengthTextScalar} - content {mustBeNonzeroLengthTextScalar} + id {llms.utils.mustBeNonzeroLengthTextScalar} + name {llms.utils.mustBeNonzeroLengthTextScalar} + content {llms.utils.mustBeNonzeroLengthTextScalar} end @@ -262,11 +262,6 @@ end end -function mustBeNonzeroLengthTextScalar(content) -mustBeNonzeroLengthText(content) -mustBeTextScalar(content) -end - function validateRegularAssistant(content) try mustBeNonzeroLengthText(content) diff --git a/ollamaChat.m b/ollamaChat.m index 08cf225..49068a4 100644 --- a/ollamaChat.m +++ b/ollamaChat.m @@ -173,10 +173,9 @@ % value is CHAT.StopSequences. % Example: ["The end.", "And that's all she wrote."] % - % - % ResponseFormat - The format of response the model returns. - % The default value is CHAT.ResponseFormat. - % "text" (default) | "json" + % ResponseFormat - The format of response the call returns. + % Default value is CHAT.ResponseFormat. + % "text" | "json" | struct | string with JSON Schema % % StreamFun - Function to callback when streaming the % result. The default value is CHAT.StreamFun. 
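
For illustration, a minimal sketch of the structured-output workflow that the `ResponseFormat` changes above enable, assuming the llms-with-matlab repository is on the MATLAB path and a local Ollama server at version 0.5.0 or newer has the `mistral` model pulled; the model choice, prompt, and prototype fields are placeholders, not part of this change:

```matlab
% Sketch under assumptions: llms-with-matlab on the MATLAB path, local
% Ollama server >= 0.5.0, "mistral" model pulled.
chat = ollamaChat("mistral");

% A struct prototype makes generate return a struct with the same field
% names and compatible field types instead of free-form text.
prototype = struct("city","London","population",9000000);
out = generate(chat, ...
    "Name one large European city and give its approximate population.", ...
    ResponseFormat=prototype);
disp(out.city)        % string scalar, e.g. "Berlin"
disp(out.population)  % double
```
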
@@ -193,7 +192,7 @@ nvp.MinP {llms.utils.mustBeValidProbability} = this.MinP nvp.TopK (1,1) {mustBeReal,mustBePositive} = this.TopK nvp.StopSequences {llms.utils.mustBeValidStop} = this.StopSequences - nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = this.ResponseFormat + nvp.ResponseFormat {llms.utils.mustBeResponseFormat} = this.ResponseFormat nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = this.TimeOut nvp.TailFreeSamplingZ (1,1) {mustBeReal} = this.TailFreeSamplingZ nvp.StreamFun (1,1) {mustBeA(nvp.StreamFun,'function_handle')} @@ -234,9 +233,16 @@ end if isfield(response.Body.Data,"error") + [versionStr, versionList] = serverVersion(nvp.Endpoint); + if llms.utils.requestsStructuredOutput(nvp.ResponseFormat) && ... + ~versionIsAtLeast(versionList, [0,5,0]) + error("llms:OllamaStructuredOutputNeeds05",llms.utils.errorMessageCatalog.getMessage("llms:OllamaStructuredOutputNeeds05", versionStr)); + end err = response.Body.Data.error; error("llms:apiReturnedError",llms.utils.errorMessageCatalog.getMessage("llms:apiReturnedError",err)); end + + text = llms.internal.reformatOutput(text,nvp.ResponseFormat); end end @@ -310,3 +316,20 @@ function mustBeIntegerOrEmpty(value) mustBeInteger(value) end end + +function [versionStr, versionList] = serverVersion(endpoint) + URL = endpoint + "/api/version"; + if ~startsWith(URL,"http") + URL = "http://" + URL; + end + versionStr = webread(URL).version; + versionList = split(versionStr,'.'); + versionList = str2double(versionList); +end + +function tf = versionIsAtLeast(version,minVersion) + tf = version(1) > minVersion(1) || ... + (version(1) == minVersion(1) && (... + version(2) > minVersion(2) || ... + (version(2) == minVersion(2) && version(3) >= minVersion(3)))); +end diff --git a/openAIChat.m b/openAIChat.m index 0cb8a4e..dafeebe 100644 --- a/openAIChat.m +++ b/openAIChat.m @@ -98,7 +98,7 @@ nvp.TopP {llms.utils.mustBeValidProbability} = 1 nvp.StopSequences {llms.utils.mustBeValidStop} = {} nvp.ResponseFormat {llms.utils.mustBeResponseFormat} = "text" - nvp.APIKey {mustBeNonzeroLengthTextScalar} + nvp.APIKey {llms.utils.mustBeNonzeroLengthTextScalar} nvp.PresencePenalty {llms.utils.mustBeValidPenalty} = 0 nvp.FrequencyPenalty {llms.utils.mustBeValidPenalty} = 0 nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 10 @@ -215,7 +215,7 @@ nvp.TopP {llms.utils.mustBeValidProbability} = this.TopP nvp.StopSequences {llms.utils.mustBeValidStop} = this.StopSequences nvp.ResponseFormat {llms.utils.mustBeResponseFormat} = this.ResponseFormat - nvp.APIKey {mustBeNonzeroLengthTextScalar} = this.APIKey + nvp.APIKey {llms.utils.mustBeNonzeroLengthTextScalar} = this.APIKey nvp.PresencePenalty {llms.utils.mustBeValidPenalty} = this.PresencePenalty nvp.FrequencyPenalty {llms.utils.mustBeValidPenalty} = this.FrequencyPenalty nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = this.TimeOut @@ -270,31 +270,6 @@ end methods(Hidden) - function mustBeValidFunctionCall(this, functionCall) - if ~isempty(functionCall) - mustBeTextScalar(functionCall); - if isempty(this.FunctionNames) - error("llms:mustSetFunctionsForCall", llms.utils.errorMessageCatalog.getMessage("llms:mustSetFunctionsForCall")); - end - mustBeMember(functionCall, ["none","auto", this.FunctionNames]); - end - end - - function toolChoice = convertToolChoice(this, toolChoice) - % if toolChoice is empty - if isempty(toolChoice) - % if Tools is not empty, the default is 'auto'. 
- if ~isempty(this.Tools) - toolChoice = "auto"; - end - elseif ~ismember(toolChoice,["auto","none"]) - % if toolChoice is not empty, then it must be "auto", "none" or in the format - % {"type": "function", "function": {"name": "my_function"}} - toolChoice = struct("type","function","function",struct("name",toolChoice)); - end - - end - function messageStruct = encodeImages(~, messageStruct) for k=1:numel(messageStruct) if isfield(messageStruct{k},"images") @@ -331,11 +306,6 @@ function mustBeValidFunctionCall(this, functionCall) end end -function mustBeNonzeroLengthTextScalar(content) -mustBeNonzeroLengthText(content) -mustBeTextScalar(content) -end - function [functionsStruct, functionNames] = functionAsStruct(functions) numFunctions = numel(functions); functionsStruct = cell(1, numFunctions); diff --git a/openAIImages.m b/openAIImages.m index c963c38..bdbade5 100644 --- a/openAIImages.m +++ b/openAIImages.m @@ -42,7 +42,7 @@ function this = openAIImages(nvp) arguments nvp.ModelName (1,1) {mustBeMember(nvp.ModelName,["dall-e-2", "dall-e-3"])} = "dall-e-2" - nvp.APIKey {mustBeNonzeroLengthTextScalar} + nvp.APIKey {llms.utils.mustBeNonzeroLengthTextScalar} nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 10 end @@ -82,7 +82,7 @@ arguments this (1,1) openAIImages - prompt {mustBeNonzeroLengthTextScalar} + prompt {llms.utils.mustBeNonzeroLengthTextScalar} nvp.NumImages (1,1) {mustBePositive, mustBeInteger,... mustBeLessThanOrEqual(nvp.NumImages,10)} = 1 nvp.Size (1,1) string {mustBeMember(nvp.Size, ["256x256", "512x512", ... @@ -172,7 +172,7 @@ arguments this (1,1) openAIImages imagePath {mustBeValidFileType(imagePath)} - prompt {mustBeNonzeroLengthTextScalar} + prompt {llms.utils.mustBeNonzeroLengthTextScalar} nvp.MaskImagePath {mustBeValidFileType(nvp.MaskImagePath)} nvp.NumImages (1,1) {mustBePositive, mustBeInteger,... mustBeLessThanOrEqual(nvp.NumImages,10)} = 1 @@ -337,11 +337,6 @@ function mustBeValidFileType(filePath) mustBeLessThan(s.bytes,4e+6) end -function mustBeNonzeroLengthTextScalar(content) -mustBeNonzeroLengthText(content) -mustBeTextScalar(content) -end - function data = myImread(URI) % imread usually, but not always, fails to read from the % https://oaidalleapiprodscus.blob.core.windows.net URLs returned by diff --git a/tests/hopenAIChat.m b/tests/hopenAIChat.m index eabb603..9c052e8 100644 --- a/tests/hopenAIChat.m +++ b/tests/hopenAIChat.m @@ -1,4 +1,4 @@ -classdef (Abstract) hopenAIChat < matlab.unittest.TestCase +classdef (Abstract) hopenAIChat < hstructuredOutput % Tests for OpenAI-based chats (openAIChat, azureChat) % Copyright 2023-2024 The MathWorks, Inc. @@ -17,8 +17,6 @@ constructor defaultModel visionModel - structuredModel - noStructuredOutputModel end methods(Test) @@ -195,66 +193,6 @@ function generateOverridesProperties(testCase) testCase.verifyThat(text, EndsWithSubstring("3, ")); end - function generateWithStructuredOutput(testCase) - import matlab.unittest.constraints.IsEqualTo - import matlab.unittest.constraints.StartsWithSubstring - res = generate(testCase.structuredModel,"Which animal produces honey?",... 
- ResponseFormat = struct(commonName = "dog", scientificName = "Canis familiaris")); - testCase.assertClass(res,"struct"); - testCase.verifySize(fieldnames(res),[2,1]); - testCase.verifyThat(res.commonName, IsEqualTo("Honeybee") | IsEqualTo("Honey bee") | IsEqualTo("Honey Bee")); - testCase.verifyThat(res.scientificName, StartsWithSubstring("Apis")); - end - - function generateListWithStructuredOutput(testCase) - prototype = struct("plantName",{"appletree","pear"}, ... - "fruit",{"apple","pear"}, ... - "edible",[true,true], ... - "ignore", missing); - res = generate(testCase.structuredModel,"What is harvested in August?", ResponseFormat = prototype); - testCase.verifyCompatibleStructs(res, prototype); - end - - function generateWithNestedStructs(testCase) - stepsPrototype = struct("explanation",{"a","b"},"assumptions",{"a","b"}); - prototype = struct("steps",stepsPrototype,"final_answer","a"); - res = generate(testCase.structuredModel,"What is the positive root of x^2-2*x+1?", ... - ResponseFormat=prototype); - testCase.verifyCompatibleStructs(res,prototype); - end - - function incompleteJSONResponse(testCase) - country = ["USA";"UK"]; - capital = ["Washington, D.C.";"London"]; - population = [345716792;69203012]; - prototype = struct("country",country,"capital",capital,"population",population); - - testCase.verifyError(@() generate(testCase.structuredModel, ... - "What are the five largest countries whose English names" + ... - " start with the letter A?", ... - ResponseFormat = prototype, MaxNumTokens=3), "llms:apiReturnedIncompleteJSON"); - end - - function generateWithExplicitSchema(testCase) - import matlab.unittest.constraints.IsSameSetAs - schema = iGetSchema(); - - genUser = generate(testCase.structuredModel,"Create a sample user",ResponseFormat=schema); - genAddress = generate(testCase.structuredModel,"Create a sample address",ResponseFormat=schema); - - testCase.verifyClass(genUser,"string"); - genUserDecoded = jsondecode(genUser); - testCase.verifyClass(genUserDecoded.item,"struct"); - testCase.verifyThat(fieldnames(genUserDecoded.item),... - IsSameSetAs({'name','age'})); - - testCase.verifyClass(genAddress,"string"); - genAddressDecoded = jsondecode(genAddress); - testCase.verifyClass(genAddressDecoded.item,"struct"); - testCase.verifyThat(fieldnames(genAddressDecoded.item),... - IsSameSetAs({'number','street','city'})); - end - function invalidInputsGenerate(testCase, InvalidGenerateInput) f = openAIFunction("validfunction"); chat = testCase.constructor(Tools=f, APIKey="this-is-not-a-real-key"); @@ -297,14 +235,6 @@ function createChatWithStreamFunc(testCase) testCase.verifyGreaterThan(numel(sf("")), 1); end - function warningJSONResponseFormat(testCase) - chat = @() testCase.constructor("You are a useful assistant", ... - APIKey="this-is-not-a-real-key", ... - ResponseFormat="json"); - - testCase.verifyWarning(@()chat(), "llms:warningJsonInstruction"); - end - function errorJSONResponseFormat(testCase) testCase.verifyError( ... @() generate(testCase.structuredModel,"create some address",ResponseFormat="json"), ... 
@@ -329,89 +259,4 @@ function keyNotFound(testCase) testCase.verifyError(testCase.constructor, "llms:keyMustBeSpecified"); end end - - methods - function verifyCompatibleStructs(testCase,data,prototype) - import matlab.unittest.constraints.IsSameSetAs - testCase.assertClass(data,"struct"); - if ~isscalar(data) - arrayfun(@(d) testCase.verifyCompatibleStructs(d,prototype), data); - return - end - testCase.assertClass(prototype,"struct"); - if ~isscalar(prototype) - prototype = prototype(1); - end - testCase.assertThat(fieldnames(data),IsSameSetAs(fieldnames(prototype))); - for name = fieldnames(data).' - field = name{1}; - testCase.verifyClass(data.(field),class(prototype.(field))); - if isstruct(data.(field)) - testCase.verifyCompatibleStructs(data.(field),prototype.(field)); - end - end - end - end -end - -function str = iGetSchema() -% an example from https://platform.openai.com/docs/guides/structured-outputs/supported-schemas -str = string(join({ - '{' - ' "type": "object",' - ' "properties": {' - ' "item": {' - ' "anyOf": [' - ' {' - ' "type": "object",' - ' "description": "The user object to insert into the database",' - ' "properties": {' - ' "name": {' - ' "type": "string",' - ' "description": "The name of the user"' - ' },' - ' "age": {' - ' "type": "number",' - ' "description": "The age of the user"' - ' }' - ' },' - ' "additionalProperties": false,' - ' "required": [' - ' "name",' - ' "age"' - ' ]' - ' },' - ' {' - ' "type": "object",' - ' "description": "The address object to insert into the database",' - ' "properties": {' - ' "number": {' - ' "type": "string",' - ' "description": "The number of the address. Eg. for 123 main st, this would be 123"' - ' },' - ' "street": {' - ' "type": "string",' - ' "description": "The street name. Eg. for 123 main st, this would be main st"' - ' },' - ' "city": {' - ' "type": "string",' - ' "description": "The city of the address"' - ' }' - ' },' - ' "additionalProperties": false,' - ' "required": [' - ' "number",' - ' "street",' - ' "city"' - ' ]' - ' }' - ' ]' - ' }' - ' },' - ' "additionalProperties": false,' - ' "required": [' - ' "item"' - ' ]' - '}' -}, newline)); end diff --git a/tests/hstructuredOutput.m b/tests/hstructuredOutput.m new file mode 100644 index 0000000..445a7ab --- /dev/null +++ b/tests/hstructuredOutput.m @@ -0,0 +1,150 @@ +classdef (Abstract) hstructuredOutput < matlab.unittest.TestCase +% Tests for completion APIs providing structured output + +% Copyright 2023-2024 The MathWorks, Inc. + + properties(Abstract) + structuredModel + end + + methods(Test) + % Test methods + function generateWithStructuredOutput(testCase) + import matlab.unittest.constraints.IsEqualTo + import matlab.unittest.constraints.StartsWithSubstring + res = generate(testCase.structuredModel,"Which animal produces honey?",... + ResponseFormat = struct(commonName = "dog", scientificName = "Canis familiaris")); + testCase.assertClass(res,"struct"); + testCase.verifySize(fieldnames(res),[2,1]); + testCase.verifyThat(lower(res.commonName), IsEqualTo("honeybee") | IsEqualTo("honey bee") | ... + IsEqualTo("bee")); + testCase.verifyThat(res.scientificName, StartsWithSubstring("Apis")); + end + + function generateListWithStructuredOutput(testCase) + prototype = struct("plantName",{"appletree","pear"}, ... + "fruit",{"apple","pear"}, ... + "edible",[true,true], ... 
+ "ignore", missing); + res = generate(testCase.structuredModel,"What is harvested in August?", ResponseFormat = prototype); + testCase.verifyCompatibleStructs(res, prototype); + end + + function generateWithNestedStructs(testCase) + stepsPrototype = struct("explanation",{"a","b"},"assumptions",{"a","b"}); + prototype = struct("steps",stepsPrototype,"final_answer","a"); + res = generate(testCase.structuredModel,"What is the positive root of x^2-2*x+1?", ... + ResponseFormat=prototype); + testCase.verifyCompatibleStructs(res,prototype); + end + + function incompleteJSONResponse(testCase) + country = ["USA";"UK"]; + capital = ["Washington, D.C.";"London"]; + population = [345716792;69203012]; + prototype = struct("country",country,"capital",capital,"population",population); + + testCase.verifyError(@() generate(testCase.structuredModel, ... + "What are the five largest countries whose English names" + ... + " start with the letter A?", ... + ResponseFormat = prototype, MaxNumTokens=3), "llms:apiReturnedIncompleteJSON"); + end + + function generateWithExplicitSchema(testCase) + import matlab.unittest.constraints.IsSameSetAs + schema = iGetSchema(); + + genUser = generate(testCase.structuredModel,"Create a sample user",ResponseFormat=schema); + + testCase.verifyClass(genUser,"string"); + genUserDecoded = jsondecode(genUser); + testCase.verifyClass(genUserDecoded.item,"struct"); + testCase.verifyThat(fieldnames(genUserDecoded.item),... + IsSameSetAs({'name','age'}) | IsSameSetAs({'number','street','city'})); + end + end + + methods + function verifyCompatibleStructs(testCase,data,prototype) + testCase.assertClass(data,"struct"); + testCase.assertClass(prototype,"struct"); + arrayfun(@(d) testCase.verifyCompatibleStructsScalar(d,prototype(1)), data); + end + + function verifyCompatibleStructsScalar(testCase,data,prototype) + import matlab.unittest.constraints.IsSameSetAs + testCase.assertClass(data,"struct"); + testCase.assertClass(prototype,"struct"); + testCase.assertThat(fieldnames(data),IsSameSetAs(fieldnames(prototype))); + for name = fieldnames(data).' + field = name{1}; + testCase.verifyClass(data.(field),class(prototype.(field))); + if isstruct(data.(field)) + testCase.verifyCompatibleStructs(data.(field),prototype.(field)); + end + end + end + end +end + +function str = iGetSchema() +% an example from https://platform.openai.com/docs/guides/structured-outputs/supported-schemas +str = string(join({ + '{' + ' "type": "object",' + ' "properties": {' + ' "item": {' + ' "anyOf": [' + ' {' + ' "type": "object",' + ' "description": "The user object to insert into the database",' + ' "properties": {' + ' "name": {' + ' "type": "string",' + ' "description": "The name of the user"' + ' },' + ' "age": {' + ' "type": "number",' + ' "description": "The age of the user"' + ' }' + ' },' + ' "additionalProperties": false,' + ' "required": [' + ' "name",' + ' "age"' + ' ]' + ' },' + ' {' + ' "type": "object",' + ' "description": "The address object to insert into the database",' + ' "properties": {' + ' "number": {' + ' "type": "string",' + ' "description": "The number of the address. Eg. for 123 main st, this would be 123"' + ' },' + ' "street": {' + ' "type": "string",' + ' "description": "The street name. Eg. 
for 123 main st, this would be main st"' + ' },' + ' "city": {' + ' "type": "string",' + ' "description": "The city of the address"' + ' }' + ' },' + ' "additionalProperties": false,' + ' "required": [' + ' "number",' + ' "street",' + ' "city"' + ' ]' + ' }' + ' ]' + ' }' + ' },' + ' "additionalProperties": false,' + ' "required": [' + ' "item"' + ' ]' + '}' +}, newline)); +end diff --git a/tests/recordings/AnalyzeSentimentinTextUsingChatGPTinJSONMode.mat b/tests/recordings/AnalyzeSentimentinTextUsingChatGPTinJSONMode.mat deleted file mode 100644 index 310d0e8..0000000 Binary files a/tests/recordings/AnalyzeSentimentinTextUsingChatGPTinJSONMode.mat and /dev/null differ diff --git a/tests/recordings/AnalyzeSentimentinTextUsingChatGPTwithStructuredOutput.mat b/tests/recordings/AnalyzeSentimentinTextUsingChatGPTwithStructuredOutput.mat new file mode 100644 index 0000000..21b5d7c Binary files /dev/null and b/tests/recordings/AnalyzeSentimentinTextUsingChatGPTwithStructuredOutput.mat differ diff --git a/tests/tazureChat.m b/tests/tazureChat.m index 6e58a64..b172d78 100644 --- a/tests/tazureChat.m +++ b/tests/tazureChat.m @@ -17,7 +17,6 @@ visionModel = azureChat(Deployment="gpt-4o"); structuredModel = azureChat("APIVersion","2024-08-01-preview",... "Deployment","gpt-4o-2024-08-06"); - noStructuredOutputModel = azureChat(APIVersion="2024-08-01-preview"); end methods(Test) @@ -171,18 +170,6 @@ function deploymentNotFound(testCase) testCase.verifyError(@()azureChat, "llms:deploymentMustBeSpecified"); end - % open TODOs for azureChat - function settingToolChoiceWithNone(testCase) - testCase.assumeFail("azureChat need different handling of ToolChoice 'none'"); - end - - function generateWithToolsAndStreamFunc(testCase) - testCase.assumeFail("need to make azureChat return tool_call in the same way as openAIChat"); - end - - function warningJSONResponseFormat(testCase) - testCase.assumeFail("TODO for azureChat"); - end end end @@ -202,6 +189,21 @@ function warningJSONResponseFormat(testCase) "SystemPrompt", {[]}, ... "ResponseFormat", {"text"} ... ) ... + ),... + "SomeSettings", struct( ... + "Input",{{"Temperature",1.23,"TopP",0.6,"TimeOut",120,"ResponseFormat","json"}}, ... + "ExpectedWarning", '', ... + "VerifyProperties", struct( ... + "Temperature", {1.23}, ... + "TopP", {0.6}, ... + "StopSequences", {string([])}, ... + "PresencePenalty", {0}, ... + "FrequencyPenalty", {0}, ... + "TimeOut", {120}, ... + "FunctionNames", {[]}, ... + "SystemPrompt", {[]}, ... + "ResponseFormat", {"json"} ... + ) ... )); end diff --git a/tests/terrorMessageCatalog.m b/tests/terrorMessageCatalog.m new file mode 100644 index 0000000..7c63baa --- /dev/null +++ b/tests/terrorMessageCatalog.m @@ -0,0 +1,27 @@ +classdef terrorMessageCatalog < matlab.unittest.TestCase +% Tests for errorMessageCatalog + +% Copyright 2024 The MathWorks, Inc. + + methods(Test) + function ensureCorrectCoverage(testCase) + testCase.verifyClass( ... + llms.utils.errorMessageCatalog.createCatalog,"dictionary"); + end + + function holeValuesAreUsed(testCase) + import matlab.unittest.constraints.IsEqualTo + + % we do not check the whole string, because error message + % text *should* be able to change without test points changing. + % That is necessary to enable localization. 
+ messageID = "llms:mustBeValidIndex"; + + message1 = llms.utils.errorMessageCatalog.getMessage(messageID, "input1"); + message2 = llms.utils.errorMessageCatalog.getMessage(messageID, "input2"); + + testCase.verifyThat(message1, ~IsEqualTo(message2)); + testCase.verifyThat(replace(message1, "input1", "input2"), IsEqualTo(message2)); + end + end +end diff --git a/tests/texampleTests.m b/tests/texampleTests.m index e08a5c9..8322550 100644 --- a/tests/texampleTests.m +++ b/tests/texampleTests.m @@ -30,7 +30,7 @@ function setUpAndTearDowns(testCase) import matlab.unittest.fixtures.CurrentFolderFixture testCase.applyFixture(CurrentFolderFixture("../examples/mlx-scripts")); - openAIEnvVar = "OPENAI_KEY"; + openAIEnvVar = "OPENAI_API_KEY"; secretKey = getenv(openAIEnvVar); % Create an empty .env file because it is expected by our .mlx % example files @@ -65,10 +65,9 @@ function testAnalyzeScientificPapersUsingFunctionCalls(testCase) AnalyzeScientificPapersUsingFunctionCalls; end - function testAnalyzeSentimentinTextUsingChatGPTinJSONMode(testCase) - testCase.startCapture("AnalyzeSentimentinTextUsingChatGPTinJSONMode"); - testCase.verifyWarning(@AnalyzeSentimentinTextUsingChatGPTinJSONMode,... - "llms:warningJsonInstruction"); + function testAnalyzeSentimentinTextUsingChatGPTwithStructuredOutput(testCase) + testCase.startCapture("AnalyzeSentimentinTextUsingChatGPTwithStructuredOutput"); + AnalyzeSentimentinTextUsingChatGPTwithStructuredOutput; end function testAnalyzeTextDataUsingParallelFunctionCallwithChatGPT(testCase) diff --git a/tests/textractOpenAIEmbeddings.m b/tests/textractOpenAIEmbeddings.m index f5352aa..da889c5 100644 --- a/tests/textractOpenAIEmbeddings.m +++ b/tests/textractOpenAIEmbeddings.m @@ -3,18 +3,6 @@ % Copyright 2023-2024 The MathWorks, Inc. - methods (TestClassSetup) - function saveEnvVar(testCase) - % Ensures key is not in environment variable for tests - openAIEnvVar = "OPENAI_API_KEY"; - if isenv(openAIEnvVar) - key = getenv(openAIEnvVar); - unsetenv(openAIEnvVar); - testCase.addTeardown(@(x) setenv(openAIEnvVar, x), key); - end - end - end - properties(TestParameter) InvalidInput = iGetInvalidInput(); ValidDimensionsModelCombinations = iGetValidDimensionsModelCombinations(); @@ -29,6 +17,13 @@ function embedsDifferentStringTypes(testCase) end function keyNotFound(testCase) + % Ensures key is not in environment variable for tests + openAIEnvVar = "OPENAI_API_KEY"; + if isenv(openAIEnvVar) + key = getenv(openAIEnvVar); + reset = onCleanup(@() setenv(openAIEnvVar, key)); + unsetenv(openAIEnvVar); + end testCase.verifyError(@()extractOpenAIEmbeddings("bla"), "llms:keyMustBeSpecified"); end @@ -40,8 +35,7 @@ function validCombinationOfModelAndDimension(testCase, ValidDimensionsModelCombi end function embedStringWithSuccessfulOpenAICall(testCase) - testCase.verifyWarningFree(@()extractOpenAIEmbeddings("bla", ... - APIKey=getenv("OPENAI_KEY"))); + testCase.verifyWarningFree(@()extractOpenAIEmbeddings("bla")); end function invalidCombinationOfModelAndDimension(testCase) diff --git a/tests/tollamaChat.m b/tests/tollamaChat.m index 8d863c7..e68684f 100644 --- a/tests/tollamaChat.m +++ b/tests/tollamaChat.m @@ -1,4 +1,4 @@ -classdef tollamaChat < matlab.unittest.TestCase +classdef tollamaChat < hstructuredOutput % Tests for ollamaChat % Copyright 2024 The MathWorks, Inc. 
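
As context for the `holeValuesAreUsed` test above, `getMessage` substitutes the supplied values into the `{1}`, `{2}`, ... holes of the stored message. A sketch using the catalog entry this change set adds, with a made-up version value:

```matlab
% "0.3.14" is a made-up example; getMessage substitutes it for the {1} hole.
msg = llms.utils.errorMessageCatalog.getMessage( ...
    "llms:OllamaStructuredOutputNeeds05","0.3.14");
% msg: "Structured output is not supported for Ollama version 0.3.14.
%       Use version 0.5.0 or newer."
```
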
@@ -11,6 +11,10 @@ StringInputs = struct('string',{"hi"},'char',{'hi'},'cellstr',{{'hi'}}); end + properties + structuredModel = ollamaChat("mistral"); + end + methods(Test) function simpleConstruction(testCase) bot = ollamaChat("mistral"); @@ -122,6 +126,8 @@ function seedFixesResult(testCase) end function generateWithImages(testCase) + testCase.assumeFail("CI only assertion failure within Ollama, as of 2024-10-30, Ollama 0.3.14"); + import matlab.unittest.constraints.ContainsSubstring chat = ollamaChat("moondream"); image_path = "peppers.png"; diff --git a/tests/topenAIChat.m b/tests/topenAIChat.m index b528378..a79f572 100644 --- a/tests/topenAIChat.m +++ b/tests/topenAIChat.m @@ -340,7 +340,7 @@ function specialErrorForUnsupportedResponseFormat(testCase) ), ... "ResponseFormat", struct( ... "Input",{{"APIKey","this-is-not-a-real-key","ResponseFormat","json"}}, ... - "ExpectedWarning", "llms:warningJsonInstruction", ... + "ExpectedWarning", '', ... "VerifyProperties", struct( ... "Temperature", {1}, ... "TopP", {1}, ... diff --git a/tests/topenAIImages.m b/tests/topenAIImages.m index 4ec024e..3d586fd 100644 --- a/tests/topenAIImages.m +++ b/tests/topenAIImages.m @@ -3,18 +3,6 @@ % Copyright 2024 The MathWorks, Inc. - methods (TestClassSetup) - function saveEnvVar(testCase) - % Ensures key is not in environment variable for tests - openAIEnvVar = "OPENAI_API_KEY"; - if isenv(openAIEnvVar) - key = getenv(openAIEnvVar); - testCase.addTeardown(@() setenv(openAIEnvVar, key)); - unsetenv(openAIEnvVar); - end - end - end - properties(TestParameter) InvalidConstructorInput = iGetInvalidConstructorInput; InvalidGenerateInput = iGetInvalidGenerateInput; @@ -32,6 +20,13 @@ function generateAcceptsSingleStringAsInput(testCase) end function keyNotFound(testCase) + % Ensures key is not in environment variable for tests + openAIEnvVar = "OPENAI_API_KEY"; + if isenv(openAIEnvVar) + key = getenv(openAIEnvVar); + reset = onCleanup(@() setenv(openAIEnvVar, key)); + unsetenv(openAIEnvVar); + end testCase.verifyError(@()openAIImages, "llms:keyMustBeSpecified"); end @@ -131,7 +126,7 @@ function invalidInputsVariation(testCase, InvalidVariationInput) end function testThatImageIsReturned(testCase) - mdl = openAIImages(APIKey=getenv("OPENAI_KEY")); + mdl = openAIImages; [images, response] = generate(mdl, ... "Create a 3D avatar of a whimsical sushi on the beach. " + ...
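
The version gate added to `ollamaChat.m` compares the dotted server version numerically, component by component, rather than lexically. A standalone sketch of the same comparison against the 0.5.0 minimum, using a made-up version string:

```matlab
% Same comparison as serverVersion/versionIsAtLeast in ollamaChat.m,
% applied to a made-up example version string.
versionStr  = "0.4.7";
versionList = str2double(split(versionStr,"."));   % -> [0;4;7]
minVersion  = [0,5,0];
tf = versionList(1) > minVersion(1) || ...
    (versionList(1) == minVersion(1) && (versionList(2) > minVersion(2) || ...
    (versionList(2) == minVersion(2) && versionList(3) >= minVersion(3))));
% tf is false here: version 0.4.7 predates structured-output support.
```
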