image_upload_test.go

  1package providertests
  2
  3import (
  4	"cmp"
  5	"net/http"
  6	"os"
  7	"testing"
  8
  9	"charm.land/fantasy"
 10	"charm.land/fantasy/providers/anthropic"
 11	"charm.land/fantasy/providers/google"
 12	"charm.land/fantasy/providers/openai"
 13	"github.com/stretchr/testify/require"
 14	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 15)
 16
 17func anthropicImageBuilder(model string) builderFunc {
 18	return func(r *recorder.Recorder) (fantasy.LanguageModel, error) {
 19		provider, err := anthropic.New(
 20			anthropic.WithAPIKey(cmp.Or(os.Getenv("FANTASY_ANTHROPIC_API_KEY"), "(missing)")),
 21			anthropic.WithHTTPClient(&http.Client{Transport: r}),
 22		)
 23		if err != nil {
 24			return nil, err
 25		}
 26		return provider.LanguageModel(model)
 27	}
 28}
 29
 30func openAIImageBuilder(model string) builderFunc {
 31	return func(r *recorder.Recorder) (fantasy.LanguageModel, error) {
 32		provider, err := openai.New(
 33			openai.WithAPIKey(cmp.Or(os.Getenv("FANTASY_OPENAI_API_KEY"), "(missing)")),
 34			openai.WithHTTPClient(&http.Client{Transport: r}),
 35		)
 36		if err != nil {
 37			return nil, err
 38		}
 39		return provider.LanguageModel(model)
 40	}
 41}
 42
 43func geminiImageBuilder(model string) builderFunc {
 44	return func(r *recorder.Recorder) (fantasy.LanguageModel, error) {
 45		provider, err := google.New(
 46			google.WithGeminiAPIKey(cmp.Or(os.Getenv("FANTASY_GEMINI_API_KEY"), "(missing)")),
 47			google.WithHTTPClient(&http.Client{Transport: r}),
 48		)
 49		if err != nil {
 50			return nil, err
 51		}
 52		return provider.LanguageModel(model)
 53	}
 54}
 55
 56func TestImageUploadAgent(t *testing.T) {
 57	pairs := []builderPair{
 58		{
 59			name:    "anthropic-claude-sonnet-4",
 60			builder: anthropicImageBuilder("claude-sonnet-4-20250514"),
 61		},
 62		{
 63			name:    "openai-gpt-5",
 64			builder: openAIImageBuilder("gpt-5"),
 65		},
 66		{
 67			name:    "gemini-2.5-pro",
 68			builder: geminiImageBuilder("gemini-2.5-pro"),
 69		},
 70	}
 71
 72	img, err := os.ReadFile("testdata/wish.png")
 73	require.NoError(t, err)
 74
 75	file := fantasy.FilePart{Filename: "wish.png", Data: img, MediaType: "image/png"}
 76
 77	for _, pair := range pairs {
 78		pair := pair
 79		t.Run(pair.name, func(t *testing.T) {
 80			r := newRecorder(t)
 81
 82			lm, err := pair.builder(r)
 83			require.NoError(t, err)
 84
 85			agent := fantasy.NewAgent(
 86				lm,
 87				fantasy.WithSystemPrompt("You are a helpful assistant"),
 88			)
 89
 90			result, err := agent.Generate(t.Context(), fantasy.AgentCall{
 91				Prompt:          "Describe the image briefly in English.",
 92				Files:           []fantasy.FilePart{file},
 93				ProviderOptions: pair.providerOptions,
 94				MaxOutputTokens: fantasy.Opt(int64(4000)),
 95			})
 96			require.NoError(t, err)
 97			got := result.Response.Content.Text()
 98			require.NotEmpty(t, got, "expected non-empty description for %s", pair.name)
 99		})
100	}
101}
102
103func TestImageUploadAgentStreaming(t *testing.T) {
104	pairs := []builderPair{
105		{
106			name:    "anthropic-claude-sonnet-4",
107			builder: anthropicImageBuilder("claude-sonnet-4-20250514"),
108		},
109		{
110			name:    "openai-gpt-5",
111			builder: openAIImageBuilder("gpt-5"),
112		},
113		{
114			name:    "gemini-2.5-pro",
115			builder: geminiImageBuilder("gemini-2.5-pro"),
116		},
117	}
118
119	img, err := os.ReadFile("testdata/wish.png")
120	require.NoError(t, err)
121
122	file := fantasy.FilePart{Filename: "wish.png", Data: img, MediaType: "image/png"}
123
124	for _, pair := range pairs {
125		pair := pair
126		t.Run(pair.name+"-stream", func(t *testing.T) {
127			r := newRecorder(t)
128
129			lm, err := pair.builder(r)
130			require.NoError(t, err)
131
132			agent := fantasy.NewAgent(
133				lm,
134				fantasy.WithSystemPrompt("You are a helpful assistant"),
135			)
136
137			result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
138				Prompt:          "Describe the image briefly in English.",
139				Files:           []fantasy.FilePart{file},
140				ProviderOptions: pair.providerOptions,
141				MaxOutputTokens: fantasy.Opt(int64(4000)),
142			})
143			require.NoError(t, err)
144			got := result.Response.Content.Text()
145			require.NotEmpty(t, got, "expected non-empty description for %s", pair.name)
146		})
147	}
148}