Handle ToolChoiceRequired for OpenAI (#119)

h8mankind created

* Fix OpenAI provider to handle ToolChoiceRequired

The OpenAI provider's toOpenAiTools() function was missing a case for fantasy.ToolChoiceRequired, causing it to fall through to the default case and treat "required" as a function name.

This resulted in API errors: 'Invalid value for function_call: no function named "required" was specified in the functions parameter.

OpenAI's API supports tool_choice: "required" (added April 2024), so this fix adds the missing case to properly map ToolChoiceRequired to "required" in the OpenAI API format.

* Add test for ToolChoiceRequired handling

Change summary

providers/openai/language_model.go |  4 ++
providers/openai/openai_test.go    | 59 ++++++++++++++++++++++++++++++++
2 files changed, 63 insertions(+)

Detailed changes

providers/openai/language_model.go 🔗

@@ -619,6 +619,10 @@ func toOpenAiTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (openAi
 		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
 			OfAuto: param.NewOpt("none"),
 		}
+	case fantasy.ToolChoiceRequired:
+		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
+			OfAuto: param.NewOpt("required"),
+		}
 	default:
 		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
 			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{

providers/openai/openai_test.go 🔗

@@ -1288,6 +1288,65 @@ func TestDoGenerate(t *testing.T) {
 		require.Equal(t, `{"value":"Spark"}`, toolCall.Input)
 	})
 
+	t.Run("should handle ToolChoiceRequired", func(t *testing.T) {
+		t.Parallel()
+
+		server := newMockServer()
+		defer server.close()
+
+		server.prepareJSONResponse(map[string]any{
+			"content": "",
+		})
+
+		provider, err := New(
+			WithAPIKey("test-api-key"),
+			WithBaseURL(server.server.URL),
+		)
+		require.NoError(t, err)
+		model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
+
+		_, err = model.Generate(context.Background(), fantasy.Call{
+			Prompt: testPrompt,
+			Tools: []fantasy.Tool{
+				fantasy.FunctionTool{
+					Name: "test-tool",
+					InputSchema: map[string]any{
+						"type": "object",
+						"properties": map[string]any{
+							"value": map[string]any{
+								"type": "string",
+							},
+						},
+						"required":             []string{"value"},
+						"additionalProperties": false,
+						"$schema":              "http://json-schema.org/draft-07/schema#",
+					},
+				},
+			},
+			ToolChoice: &[]fantasy.ToolChoice{fantasy.ToolChoiceRequired}[0],
+		})
+
+		require.NoError(t, err)
+		require.Len(t, server.calls, 1)
+
+		call := server.calls[0]
+		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
+
+		// Verify tool is present
+		tools := call.body["tools"].([]any)
+		require.Len(t, tools, 1)
+
+		tool := tools[0].(map[string]any)
+		require.Equal(t, "function", tool["type"])
+
+		function := tool["function"].(map[string]any)
+		require.Equal(t, "test-tool", function["name"])
+
+		// Verify tool_choice is set to "required" (not a function name)
+		toolChoice := call.body["tool_choice"]
+		require.Equal(t, "required", toolChoice)
+	})
+
 	t.Run("should parse annotations/citations", func(t *testing.T) {
 		t.Parallel()