diff --git a/providers/openai/language_model.go b/providers/openai/language_model.go index f37cee5280c9b299e17caded29c673e003477dcf..9df357ac878adbe839f914d91acb0e950a1cf4e3 100644 --- a/providers/openai/language_model.go +++ b/providers/openai/language_model.go @@ -619,6 +619,10 @@ func toOpenAiTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (openAi openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{ OfAuto: param.NewOpt("none"), } + case fantasy.ToolChoiceRequired: + openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{ + OfAuto: param.NewOpt("required"), + } default: openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{ OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{ diff --git a/providers/openai/openai_test.go b/providers/openai/openai_test.go index 97296fd974d9cb069649e3dcdb167e88994dfe0b..97e5a8e9fe97e452749411ca8cc635db221be503 100644 --- a/providers/openai/openai_test.go +++ b/providers/openai/openai_test.go @@ -1288,6 +1288,65 @@ func TestDoGenerate(t *testing.T) { require.Equal(t, `{"value":"Spark"}`, toolCall.Input) }) + t.Run("should handle ToolChoiceRequired", func(t *testing.T) { + t.Parallel() + + server := newMockServer() + defer server.close() + + server.prepareJSONResponse(map[string]any{ + "content": "", + }) + + provider, err := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), + ) + require.NoError(t, err) + model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo") + + _, err = model.Generate(context.Background(), fantasy.Call{ + Prompt: testPrompt, + Tools: []fantasy.Tool{ + fantasy.FunctionTool{ + Name: "test-tool", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "value": map[string]any{ + "type": "string", + }, + }, + "required": []string{"value"}, + "additionalProperties": false, + "$schema": "http://json-schema.org/draft-07/schema#", + }, + }, + }, + ToolChoice: &[]fantasy.ToolChoice{fantasy.ToolChoiceRequired}[0], + }) + + require.NoError(t, err) + require.Len(t, server.calls, 1) + + call := server.calls[0] + require.Equal(t, "gpt-3.5-turbo", call.body["model"]) + + // Verify tool is present + tools := call.body["tools"].([]any) + require.Len(t, tools, 1) + + tool := tools[0].(map[string]any) + require.Equal(t, "function", tool["type"]) + + function := tool["function"].(map[string]any) + require.Equal(t, "test-tool", function["name"]) + + // Verify tool_choice is set to "required" (not a function name) + toolChoice := call.body["tool_choice"] + require.Equal(t, "required", toolChoice) + }) + t.Run("should parse annotations/citations", func(t *testing.T) { t.Parallel()