@@ -619,6 +619,10 @@ func toOpenAiTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (openAi
openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
OfAuto: param.NewOpt("none"),
}
+ case fantasy.ToolChoiceRequired:
+ openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
+ OfAuto: param.NewOpt("required"),
+ }
default:
openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
@@ -1288,6 +1288,65 @@ func TestDoGenerate(t *testing.T) {
require.Equal(t, `{"value":"Spark"}`, toolCall.Input)
})
+ t.Run("should handle ToolChoiceRequired", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{
+ "content": "",
+ })
+
+ provider, err := New(
+ WithAPIKey("test-api-key"),
+ WithBaseURL(server.server.URL),
+ )
+ require.NoError(t, err)
+ model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
+
+ _, err = model.Generate(context.Background(), fantasy.Call{
+ Prompt: testPrompt,
+ Tools: []fantasy.Tool{
+ fantasy.FunctionTool{
+ Name: "test-tool",
+ InputSchema: map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "value": map[string]any{
+ "type": "string",
+ },
+ },
+ "required": []string{"value"},
+ "additionalProperties": false,
+ "$schema": "http://json-schema.org/draft-07/schema#",
+ },
+ },
+ },
+ ToolChoice: &[]fantasy.ToolChoice{fantasy.ToolChoiceRequired}[0],
+ })
+
+ require.NoError(t, err)
+ require.Len(t, server.calls, 1)
+
+ call := server.calls[0]
+ require.Equal(t, "gpt-3.5-turbo", call.body["model"])
+
+ // Verify tool is present
+ tools := call.body["tools"].([]any)
+ require.Len(t, tools, 1)
+
+ tool := tools[0].(map[string]any)
+ require.Equal(t, "function", tool["type"])
+
+ function := tool["function"].(map[string]any)
+ require.Equal(t, "test-tool", function["name"])
+
+ // Verify tool_choice is set to "required" (not a function name)
+ toolChoice := call.body["tool_choice"]
+ require.Equal(t, "required", toolChoice)
+ })
+
t.Run("should parse annotations/citations", func(t *testing.T) {
t.Parallel()