image_upload_test.go

  1package providertests
  2
  3import (
  4	"cmp"
  5	"net/http"
  6	"os"
  7	"testing"
  8
  9	"charm.land/fantasy"
 10	"charm.land/fantasy/providers/anthropic"
 11	"charm.land/fantasy/providers/google"
 12	"charm.land/fantasy/providers/openai"
 13	"github.com/stretchr/testify/require"
 14	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 15)
 16
 17func anthropicImageBuilder(model string) builderFunc {
 18	return func(r *recorder.Recorder) (fantasy.LanguageModel, error) {
 19		provider := anthropic.New(
 20			anthropic.WithAPIKey(cmp.Or(os.Getenv("FANTASY_ANTHROPIC_API_KEY"), "(missing)")),
 21			anthropic.WithHTTPClient(&http.Client{Transport: r}),
 22		)
 23		return provider.LanguageModel(model)
 24	}
 25}
 26
 27func openAIImageBuilder(model string) builderFunc {
 28	return func(r *recorder.Recorder) (fantasy.LanguageModel, error) {
 29		provider := openai.New(
 30			openai.WithAPIKey(cmp.Or(os.Getenv("FANTASY_OPENAI_API_KEY"), "(missing)")),
 31			openai.WithHTTPClient(&http.Client{Transport: r}),
 32		)
 33		return provider.LanguageModel(model)
 34	}
 35}
 36
 37func geminiImageBuilder(model string) builderFunc {
 38	return func(r *recorder.Recorder) (fantasy.LanguageModel, error) {
 39		provider := google.New(
 40			google.WithGeminiAPIKey(cmp.Or(os.Getenv("FANTASY_GEMINI_API_KEY"), "(missing)")),
 41			google.WithHTTPClient(&http.Client{Transport: r}),
 42		)
 43		return provider.LanguageModel(model)
 44	}
 45}
 46
 47func TestImageUploadAgent(t *testing.T) {
 48	pairs := []builderPair{
 49		{
 50			name:    "anthropic-claude-sonnet-4",
 51			builder: anthropicImageBuilder("claude-sonnet-4-20250514"),
 52		},
 53		{
 54			name:    "openai-gpt-5",
 55			builder: openAIImageBuilder("gpt-5"),
 56		},
 57		{
 58			name:    "gemini-2.5-pro",
 59			builder: geminiImageBuilder("gemini-2.5-pro"),
 60		},
 61	}
 62
 63	img, err := os.ReadFile("testdata/wish.png")
 64	require.NoError(t, err)
 65
 66	file := fantasy.FilePart{Filename: "wish.png", Data: img, MediaType: "image/png"}
 67
 68	for _, pair := range pairs {
 69		pair := pair
 70		t.Run(pair.name, func(t *testing.T) {
 71			r := newRecorder(t)
 72
 73			lm, err := pair.builder(r)
 74			require.NoError(t, err)
 75
 76			agent := fantasy.NewAgent(
 77				lm,
 78				fantasy.WithSystemPrompt("You are a helpful assistant"),
 79			)
 80
 81			result, err := agent.Generate(t.Context(), fantasy.AgentCall{
 82				Prompt:          "Describe the image briefly in English.",
 83				Files:           []fantasy.FilePart{file},
 84				ProviderOptions: pair.providerOptions,
 85				MaxOutputTokens: fantasy.Opt(int64(4000)),
 86			})
 87			require.NoError(t, err)
 88			got := result.Response.Content.Text()
 89			require.NotEmpty(t, got, "expected non-empty description for %s", pair.name)
 90		})
 91	}
 92}
 93
 94func TestImageUploadAgentStreaming(t *testing.T) {
 95	pairs := []builderPair{
 96		{
 97			name:    "anthropic-claude-sonnet-4",
 98			builder: anthropicImageBuilder("claude-sonnet-4-20250514"),
 99		},
100		{
101			name:    "openai-gpt-5",
102			builder: openAIImageBuilder("gpt-5"),
103		},
104		{
105			name:    "gemini-2.5-pro",
106			builder: geminiImageBuilder("gemini-2.5-pro"),
107		},
108	}
109
110	img, err := os.ReadFile("testdata/wish.png")
111	require.NoError(t, err)
112
113	file := fantasy.FilePart{Filename: "wish.png", Data: img, MediaType: "image/png"}
114
115	for _, pair := range pairs {
116		pair := pair
117		t.Run(pair.name+"-stream", func(t *testing.T) {
118			r := newRecorder(t)
119
120			lm, err := pair.builder(r)
121			require.NoError(t, err)
122
123			agent := fantasy.NewAgent(
124				lm,
125				fantasy.WithSystemPrompt("You are a helpful assistant"),
126			)
127
128			result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
129				Prompt:          "Describe the image briefly in English.",
130				Files:           []fantasy.FilePart{file},
131				ProviderOptions: pair.providerOptions,
132				MaxOutputTokens: fantasy.Opt(int64(4000)),
133			})
134			require.NoError(t, err)
135			got := result.Response.Content.Text()
136			require.NotEmpty(t, got, "expected non-empty description for %s", pair.name)
137		})
138	}
139}