provider_registry_test.go

  1package providertests
  2
  3import (
  4	"encoding/json"
  5	"testing"
  6
  7	"charm.land/fantasy"
  8	"charm.land/fantasy/providers/anthropic"
  9	"charm.land/fantasy/providers/google"
 10	"charm.land/fantasy/providers/openai"
 11	"charm.land/fantasy/providers/openaicompat"
 12	"charm.land/fantasy/providers/openrouter"
 13	"github.com/stretchr/testify/require"
 14)
 15
 16func TestProviderRegistry_Serialization_OpenAIOptions(t *testing.T) {
 17	msg := fantasy.Message{
 18		Role: fantasy.MessageRoleUser,
 19		Content: []fantasy.MessagePart{
 20			fantasy.TextPart{Text: "hi"},
 21		},
 22		ProviderOptions: fantasy.ProviderOptions{
 23			openai.Name: &openai.ProviderOptions{User: fantasy.Opt("tester")},
 24		},
 25	}
 26
 27	data, err := json.Marshal(msg)
 28	require.NoError(t, err)
 29
 30	var raw struct {
 31		ProviderOptions map[string]map[string]any `json:"provider_options"`
 32	}
 33	require.NoError(t, json.Unmarshal(data, &raw))
 34
 35	po, ok := raw.ProviderOptions[openai.Name]
 36	require.True(t, ok)
 37	require.Equal(t, openai.TypeProviderOptions, po["type"]) // no magic strings
 38	// ensure inner data has the field we set
 39	inner, ok := po["data"].(map[string]any)
 40	require.True(t, ok)
 41	require.Equal(t, "tester", inner["user"])
 42
 43	var decoded fantasy.Message
 44	require.NoError(t, json.Unmarshal(data, &decoded))
 45
 46	got, ok := decoded.ProviderOptions[openai.Name]
 47	require.True(t, ok)
 48	opt, ok := got.(*openai.ProviderOptions)
 49	require.True(t, ok)
 50	require.NotNil(t, opt.User)
 51	require.Equal(t, "tester", *opt.User)
 52}
 53
 54func TestProviderRegistry_Serialization_OpenAIResponses(t *testing.T) {
 55	// Use ResponsesProviderOptions in provider options
 56	msg := fantasy.Message{
 57		Role: fantasy.MessageRoleUser,
 58		Content: []fantasy.MessagePart{
 59			fantasy.TextPart{Text: "hello"},
 60		},
 61		ProviderOptions: fantasy.ProviderOptions{
 62			openai.Name: &openai.ResponsesProviderOptions{
 63				PromptCacheKey:    fantasy.Opt("cache-key-1"),
 64				ParallelToolCalls: fantasy.Opt(true),
 65			},
 66		},
 67	}
 68
 69	data, err := json.Marshal(msg)
 70	require.NoError(t, err)
 71
 72	// JSON should include the typed wrapper with constant TypeResponsesProviderOptions
 73	var raw struct {
 74		ProviderOptions map[string]map[string]any `json:"provider_options"`
 75	}
 76	require.NoError(t, json.Unmarshal(data, &raw))
 77
 78	po := raw.ProviderOptions[openai.Name]
 79	require.Equal(t, openai.TypeResponsesProviderOptions, po["type"]) // no magic strings
 80	inner, ok := po["data"].(map[string]any)
 81	require.True(t, ok)
 82	require.Equal(t, "cache-key-1", inner["prompt_cache_key"])
 83	require.Equal(t, true, inner["parallel_tool_calls"])
 84
 85	// Unmarshal back and assert concrete type
 86	var decoded fantasy.Message
 87	require.NoError(t, json.Unmarshal(data, &decoded))
 88	got := decoded.ProviderOptions[openai.Name]
 89	reqOpts, ok := got.(*openai.ResponsesProviderOptions)
 90	require.True(t, ok)
 91	require.NotNil(t, reqOpts.PromptCacheKey)
 92	require.Equal(t, "cache-key-1", *reqOpts.PromptCacheKey)
 93	require.NotNil(t, reqOpts.ParallelToolCalls)
 94	require.Equal(t, true, *reqOpts.ParallelToolCalls)
 95}
 96
 97func TestProviderRegistry_Serialization_OpenAIResponsesReasoningMetadata(t *testing.T) {
 98	resp := fantasy.Response{
 99		Content: []fantasy.Content{
100			fantasy.TextContent{
101				Text: "",
102				ProviderMetadata: fantasy.ProviderMetadata{
103					openai.Name: &openai.ResponsesReasoningMetadata{
104						ItemID:  "item-123",
105						Summary: []string{"part1", "part2"},
106					},
107				},
108			},
109		},
110	}
111
112	data, err := json.Marshal(resp)
113	require.NoError(t, err)
114
115	// Ensure the provider metadata is wrapped with type using constant
116	var raw struct {
117		Content []struct {
118			Type string         `json:"type"`
119			Data map[string]any `json:"data"`
120		} `json:"content"`
121	}
122	require.NoError(t, json.Unmarshal(data, &raw))
123	require.Greater(t, len(raw.Content), 0)
124	tc := raw.Content[0]
125	pm, ok := tc.Data["provider_metadata"].(map[string]any)
126	require.True(t, ok)
127	om, ok := pm[openai.Name].(map[string]any)
128	require.True(t, ok)
129	require.Equal(t, openai.TypeResponsesReasoningMetadata, om["type"]) // no magic strings
130	inner, ok := om["data"].(map[string]any)
131	require.True(t, ok)
132	require.Equal(t, "item-123", inner["item_id"])
133
134	// Unmarshal back
135	var decoded fantasy.Response
136	require.NoError(t, json.Unmarshal(data, &decoded))
137	pmDecoded := decoded.Content[0].(fantasy.TextContent).ProviderMetadata
138	val, ok := pmDecoded[openai.Name]
139	require.True(t, ok)
140	meta, ok := val.(*openai.ResponsesReasoningMetadata)
141	require.True(t, ok)
142	require.Equal(t, "item-123", meta.ItemID)
143	require.Equal(t, []string{"part1", "part2"}, meta.Summary)
144}
145
146func TestProviderRegistry_Serialization_AnthropicOptions(t *testing.T) {
147	sendReasoning := true
148	msg := fantasy.Message{
149		Role: fantasy.MessageRoleUser,
150		Content: []fantasy.MessagePart{
151			fantasy.TextPart{Text: "test message"},
152		},
153		ProviderOptions: fantasy.ProviderOptions{
154			anthropic.Name: &anthropic.ProviderOptions{
155				SendReasoning: &sendReasoning,
156			},
157		},
158	}
159
160	data, err := json.Marshal(msg)
161	require.NoError(t, err)
162
163	var decoded fantasy.Message
164	require.NoError(t, json.Unmarshal(data, &decoded))
165
166	got, ok := decoded.ProviderOptions[anthropic.Name]
167	require.True(t, ok)
168	opt, ok := got.(*anthropic.ProviderOptions)
169	require.True(t, ok)
170	require.NotNil(t, opt.SendReasoning)
171	require.Equal(t, true, *opt.SendReasoning)
172}
173
174func TestProviderRegistry_Serialization_GoogleOptions(t *testing.T) {
175	msg := fantasy.Message{
176		Role: fantasy.MessageRoleUser,
177		Content: []fantasy.MessagePart{
178			fantasy.TextPart{Text: "test message"},
179		},
180		ProviderOptions: fantasy.ProviderOptions{
181			google.Name: &google.ProviderOptions{
182				CachedContent: "cached-123",
183				Threshold:     "BLOCK_ONLY_HIGH",
184			},
185		},
186	}
187
188	data, err := json.Marshal(msg)
189	require.NoError(t, err)
190
191	var decoded fantasy.Message
192	require.NoError(t, json.Unmarshal(data, &decoded))
193
194	got, ok := decoded.ProviderOptions[google.Name]
195	require.True(t, ok)
196	opt, ok := got.(*google.ProviderOptions)
197	require.True(t, ok)
198	require.Equal(t, "cached-123", opt.CachedContent)
199	require.Equal(t, "BLOCK_ONLY_HIGH", opt.Threshold)
200}
201
202func TestProviderRegistry_Serialization_OpenRouterOptions(t *testing.T) {
203	includeUsage := true
204	msg := fantasy.Message{
205		Role: fantasy.MessageRoleUser,
206		Content: []fantasy.MessagePart{
207			fantasy.TextPart{Text: "test message"},
208		},
209		ProviderOptions: fantasy.ProviderOptions{
210			openrouter.Name: &openrouter.ProviderOptions{
211				IncludeUsage: &includeUsage,
212				User:         fantasy.Opt("test-user"),
213			},
214		},
215	}
216
217	data, err := json.Marshal(msg)
218	require.NoError(t, err)
219
220	var decoded fantasy.Message
221	require.NoError(t, json.Unmarshal(data, &decoded))
222
223	got, ok := decoded.ProviderOptions[openrouter.Name]
224	require.True(t, ok)
225	opt, ok := got.(*openrouter.ProviderOptions)
226	require.True(t, ok)
227	require.NotNil(t, opt.IncludeUsage)
228	require.Equal(t, true, *opt.IncludeUsage)
229	require.NotNil(t, opt.User)
230	require.Equal(t, "test-user", *opt.User)
231}
232
233func TestProviderRegistry_Serialization_OpenAICompatOptions(t *testing.T) {
234	effort := openai.ReasoningEffortHigh
235	msg := fantasy.Message{
236		Role: fantasy.MessageRoleUser,
237		Content: []fantasy.MessagePart{
238			fantasy.TextPart{Text: "test message"},
239		},
240		ProviderOptions: fantasy.ProviderOptions{
241			openaicompat.Name: &openaicompat.ProviderOptions{
242				User:            fantasy.Opt("test-user"),
243				ReasoningEffort: &effort,
244			},
245		},
246	}
247
248	data, err := json.Marshal(msg)
249	require.NoError(t, err)
250
251	var decoded fantasy.Message
252	require.NoError(t, json.Unmarshal(data, &decoded))
253
254	got, ok := decoded.ProviderOptions[openaicompat.Name]
255	require.True(t, ok)
256	opt, ok := got.(*openaicompat.ProviderOptions)
257	require.True(t, ok)
258	require.NotNil(t, opt.User)
259	require.Equal(t, "test-user", *opt.User)
260	require.NotNil(t, opt.ReasoningEffort)
261	require.Equal(t, openai.ReasoningEffortHigh, *opt.ReasoningEffort)
262}
263
264func TestProviderRegistry_MultiProvider(t *testing.T) {
265	// Test with multiple providers in one message
266	sendReasoning := true
267	msg := fantasy.Message{
268		Role: fantasy.MessageRoleUser,
269		Content: []fantasy.MessagePart{
270			fantasy.TextPart{Text: "test"},
271		},
272		ProviderOptions: fantasy.ProviderOptions{
273			openai.Name: &openai.ProviderOptions{User: fantasy.Opt("user1")},
274			anthropic.Name: &anthropic.ProviderOptions{
275				SendReasoning: &sendReasoning,
276			},
277		},
278	}
279
280	data, err := json.Marshal(msg)
281	require.NoError(t, err)
282
283	var decoded fantasy.Message
284	require.NoError(t, json.Unmarshal(data, &decoded))
285
286	// Check OpenAI options
287	openaiOpt, ok := decoded.ProviderOptions[openai.Name]
288	require.True(t, ok)
289	openaiData, ok := openaiOpt.(*openai.ProviderOptions)
290	require.True(t, ok)
291	require.Equal(t, "user1", *openaiData.User)
292
293	// Check Anthropic options
294	anthropicOpt, ok := decoded.ProviderOptions[anthropic.Name]
295	require.True(t, ok)
296	anthropicData, ok := anthropicOpt.(*anthropic.ProviderOptions)
297	require.True(t, ok)
298	require.Equal(t, true, *anthropicData.SendReasoning)
299}
300
301func TestProviderRegistry_ErrorHandling(t *testing.T) {
302	t.Run("unknown provider type", func(t *testing.T) {
303		invalidJSON := `{
304			"role": "user",
305			"content": [{"type": "text", "data": {"text": "hi"}}],
306			"provider_options": {
307				"unknown": {
308					"type": "unknown.provider.type",
309					"data": {}
310				}
311			}
312		}`
313
314		var msg fantasy.Message
315		err := json.Unmarshal([]byte(invalidJSON), &msg)
316		require.Error(t, err)
317		require.Contains(t, err.Error(), "unknown provider data type")
318	})
319
320	t.Run("malformed provider data", func(t *testing.T) {
321		invalidJSON := `{
322			"role": "user",
323			"content": [{"type": "text", "data": {"text": "hi"}}],
324			"provider_options": {
325				"openai": "not-an-object"
326			}
327		}`
328
329		var msg fantasy.Message
330		err := json.Unmarshal([]byte(invalidJSON), &msg)
331		require.Error(t, err)
332	})
333}
334
335func TestProviderRegistry_AllTypesRegistered(t *testing.T) {
336	// Verify all expected provider types are registered
337	// We test that unmarshaling with proper type IDs doesn't fail with "unknown provider data type"
338	tests := []struct {
339		name         string
340		providerName string
341		data         fantasy.ProviderOptionsData
342	}{
343		{"OpenAI Options", openai.Name, &openai.ProviderOptions{}},
344		{"OpenAI File Options", openai.Name, &openai.ProviderFileOptions{}},
345		{"OpenAI Metadata", openai.Name, &openai.ProviderMetadata{}},
346		{"OpenAI Responses Options", openai.Name, &openai.ResponsesProviderOptions{}},
347		{"Anthropic Options", anthropic.Name, &anthropic.ProviderOptions{}},
348		{"Google Options", google.Name, &google.ProviderOptions{}},
349		{"OpenRouter Options", openrouter.Name, &openrouter.ProviderOptions{}},
350		{"OpenAICompat Options", openaicompat.Name, &openaicompat.ProviderOptions{}},
351	}
352
353	for _, tc := range tests {
354		t.Run(tc.name, func(t *testing.T) {
355			// Create a message with the provider options
356			msg := fantasy.Message{
357				Role: fantasy.MessageRoleUser,
358				Content: []fantasy.MessagePart{
359					fantasy.TextPart{Text: "test"},
360				},
361				ProviderOptions: fantasy.ProviderOptions{
362					tc.providerName: tc.data,
363				},
364			}
365
366			// Marshal and unmarshal
367			data, err := json.Marshal(msg)
368			require.NoError(t, err)
369
370			var decoded fantasy.Message
371			err = json.Unmarshal(data, &decoded)
372			require.NoError(t, err)
373
374			// Verify the provider options exist
375			_, ok := decoded.ProviderOptions[tc.providerName]
376			require.True(t, ok, "Provider options should be present after round-trip")
377		})
378	}
379
380	// Test metadata types separately as they go in different field
381	metadataTests := []struct {
382		name         string
383		providerName string
384		data         fantasy.ProviderOptionsData
385	}{
386		{"OpenAI Responses Reasoning Metadata", openai.Name, &openai.ResponsesReasoningMetadata{}},
387		{"Anthropic Reasoning Metadata", anthropic.Name, &anthropic.ReasoningOptionMetadata{}},
388		{"Google Reasoning Metadata", google.Name, &google.ReasoningMetadata{}},
389		{"OpenRouter Metadata", openrouter.Name, &openrouter.ProviderMetadata{}},
390	}
391
392	for _, tc := range metadataTests {
393		t.Run(tc.name, func(t *testing.T) {
394			// Create a response with provider metadata
395			resp := fantasy.Response{
396				Content: []fantasy.Content{
397					fantasy.TextContent{
398						Text: "test",
399						ProviderMetadata: fantasy.ProviderMetadata{
400							tc.providerName: tc.data,
401						},
402					},
403				},
404			}
405
406			// Marshal and unmarshal
407			data, err := json.Marshal(resp)
408			require.NoError(t, err)
409
410			var decoded fantasy.Response
411			err = json.Unmarshal(data, &decoded)
412			require.NoError(t, err)
413
414			// Verify the provider metadata exists
415			textContent, ok := decoded.Content[0].(fantasy.TextContent)
416			require.True(t, ok)
417			_, ok = textContent.ProviderMetadata[tc.providerName]
418			require.True(t, ok, "Provider metadata should be present after round-trip")
419		})
420	}
421}