openai_web_search_test.go

  1package providertests
  2
  3import (
  4	"cmp"
  5	"net/http"
  6	"os"
  7	"testing"
  8
  9	"charm.land/fantasy"
 10	"charm.land/fantasy/providers/openai"
 11	"charm.land/x/vcr"
 12	"github.com/stretchr/testify/require"
 13)
 14
 15func openAIWebSearchBuilder(model string) builderFunc {
 16	return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
 17		opts := []openai.Option{
 18			openai.WithAPIKey(cmp.Or(os.Getenv("FANTASY_OPENAI_API_KEY"), os.Getenv("OPENAI_API_KEY"), "(missing)")),
 19			openai.WithHTTPClient(&http.Client{Transport: r}),
 20			openai.WithUseResponsesAPI(),
 21		}
 22		provider, err := openai.New(opts...)
 23		if err != nil {
 24			return nil, err
 25		}
 26		return provider.LanguageModel(t.Context(), model)
 27	}
 28}
 29
 30// TestOpenAIWebSearch tests web search tool support via the agent
 31// using WithProviderDefinedTools on the OpenAI Responses API.
 32func TestOpenAIWebSearch(t *testing.T) {
 33	model := "gpt-4.1"
 34	webSearchTool := openai.WebSearchTool(nil)
 35
 36	t.Run("generate", func(t *testing.T) {
 37		r := vcr.NewRecorder(t)
 38
 39		lm, err := openAIWebSearchBuilder(model)(t, r)
 40		require.NoError(t, err)
 41
 42		agent := fantasy.NewAgent(
 43			lm,
 44			fantasy.WithSystemPrompt("You are a helpful assistant"),
 45			fantasy.WithProviderDefinedTools(webSearchTool),
 46		)
 47
 48		result, err := agent.Generate(t.Context(), fantasy.AgentCall{
 49			Prompt:          "What is the current population of Tokyo? Cite your source.",
 50			MaxOutputTokens: fantasy.Opt(int64(4000)),
 51		})
 52		require.NoError(t, err)
 53
 54		got := result.Response.Content.Text()
 55		require.NotEmpty(t, got, "should have a text response")
 56		require.Contains(t, got, "Tokyo", "response should mention Tokyo")
 57
 58		// Walk the steps and verify web search content was produced.
 59		var sources []fantasy.SourceContent
 60		var providerToolCalls []fantasy.ToolCallContent
 61		for _, step := range result.Steps {
 62			for _, c := range step.Content {
 63				switch v := c.(type) {
 64				case fantasy.ToolCallContent:
 65					if v.ProviderExecuted {
 66						providerToolCalls = append(providerToolCalls, v)
 67					}
 68				case fantasy.SourceContent:
 69					sources = append(sources, v)
 70				}
 71			}
 72		}
 73
 74		require.NotEmpty(t, providerToolCalls, "should have provider-executed tool calls")
 75		require.Equal(t, "web_search", providerToolCalls[0].ToolName)
 76		// Sources come from url_citation annotations; the model
 77		// may or may not include inline citations so we don't
 78		// require them, but if present they should have URLs.
 79		for _, src := range sources {
 80			require.NotEmpty(t, src.URL, "source should have a URL")
 81		}
 82	})
 83
 84	t.Run("stream", func(t *testing.T) {
 85		r := vcr.NewRecorder(t)
 86
 87		lm, err := openAIWebSearchBuilder(model)(t, r)
 88		require.NoError(t, err)
 89
 90		agent := fantasy.NewAgent(
 91			lm,
 92			fantasy.WithSystemPrompt("You are a helpful assistant"),
 93			fantasy.WithProviderDefinedTools(webSearchTool),
 94		)
 95
 96		// Turn 1: initial query triggers web search.
 97		result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
 98			Prompt:          "What is the current population of Tokyo? Cite your source.",
 99			MaxOutputTokens: fantasy.Opt(int64(4000)),
100		})
101		require.NoError(t, err)
102
103		got := result.Response.Content.Text()
104		require.NotEmpty(t, got, "should have a text response")
105		require.Contains(t, got, "Tokyo", "response should mention Tokyo")
106
107		// Verify provider-executed tool calls and results in steps.
108		var providerToolCalls []fantasy.ToolCallContent
109		var providerToolResults []fantasy.ToolResultContent
110		for _, step := range result.Steps {
111			for _, c := range step.Content {
112				switch v := c.(type) {
113				case fantasy.ToolCallContent:
114					if v.ProviderExecuted {
115						providerToolCalls = append(providerToolCalls, v)
116					}
117				case fantasy.ToolResultContent:
118					if v.ProviderExecuted {
119						providerToolResults = append(providerToolResults, v)
120					}
121				}
122			}
123		}
124		require.NotEmpty(t, providerToolCalls, "should have provider-executed tool calls")
125		require.Equal(t, "web_search", providerToolCalls[0].ToolName)
126		require.NotEmpty(t, providerToolResults, "should have provider-executed tool results")
127	})
128}