1package providertests
2
3import (
4 "context"
5 "strings"
6 "testing"
7
8 "charm.land/fantasy"
9 "charm.land/x/vcr"
10 "github.com/stretchr/testify/require"
11)
12
// Object generation tests for providers.
//
// These helpers test structured object generation (GenerateObject and
// StreamObject) for any provider implementation.
//
// Usage example:
//
//	func TestMyProviderObjectGeneration(t *testing.T) {
//		var pairs []builderPair
//		for _, m := range myTestModels {
//			pairs = append(pairs, builderPair{m.name, myBuilder(m.model), nil, nil})
//		}
//		testObjectGeneration(t, pairs)
//	}
//
// testObjectGeneration covers:
//   - Simple object generation (flat schema with basic types)
//   - Complex object generation (nested objects and arrays)
//   - Both of the above via GenerateObject and via StreamObject (for streaming,
//     at least one object update must arrive and the final object is validated)
//
// Object generation with a custom repair function is covered separately by
// testObjectWithRepair, which must be called explicitly.
33
// testObjectGeneration runs every structured-output suite against each builder
// pair. Each pair gets its own subtest named after pair.name. Inside it, the
// simple and complex suites run in order, and each suite exercises both
// GenerateObject and StreamObject.
func testObjectGeneration(t *testing.T, pairs []builderPair) {
	suites := []func(*testing.T, builderPair){
		testSimpleObject,
		testComplexObject,
	}
	for i := range pairs {
		pair := pairs[i]
		t.Run(pair.name, func(t *testing.T) {
			for _, runSuite := range suites {
				runSuite(t, pair)
			}
		})
	}
}
44
// testSimpleObject checks flat structured output. It requests a "Person"
// object (string name, integer age, string city) through both GenerateObject
// and StreamObject. Each variant runs in its own subtest with its own VCR
// recorder, and both are validated by the same checker so they cannot drift
// apart.
func testSimpleObject(t *testing.T, pair builderPair) {
	// Define a simple schema for a person object
	schema := fantasy.Schema{
		Type: "object",
		Properties: map[string]*fantasy.Schema{
			"name": {
				Type:        "string",
				Description: "The person's name",
			},
			"age": {
				Type:        "integer",
				Description: "The person's age",
			},
			"city": {
				Type:        "string",
				Description: "The city where the person lives",
			},
		},
		Required: []string{"name", "age", "city"},
	}

	// newCall builds the request used by both subtests.
	newCall := func() fantasy.ObjectCall {
		prompt := fantasy.Prompt{
			fantasy.NewUserMessage("Generate information about a person named Alice who is 30 years old and lives in Paris."),
		}
		return fantasy.ObjectCall{
			Prompt:            prompt,
			Schema:            schema,
			SchemaName:        "Person",
			SchemaDescription: "A person with name, age, and city",
			MaxOutputTokens:   fantasy.Opt(int64(4000)),
			ProviderOptions:   pair.providerOptions,
		}
	}

	// checkPerson validates the decoded object and the reported token usage.
	// Raw text is checked by the caller because only GenerateObject is required
	// to return it.
	checkPerson := func(t *testing.T, obj any, usage fantasy.Usage) {
		t.Helper()
		require.NotNil(t, obj, "object should not be nil")
		require.Greater(t, usage.TotalTokens, int64(0), "usage should be tracked")

		objMap, ok := obj.(map[string]any)
		require.True(t, ok, "object should be a map")
		require.Contains(t, objMap, "name")
		require.Contains(t, objMap, "age")
		require.Contains(t, objMap, "city")

		name, ok := objMap["name"].(string)
		require.True(t, ok, "name should be a string")
		require.NotEmpty(t, name, "name should not be empty")

		// Assumes the object was decoded from JSON into `any`, where numbers
		// (including the schema's "integer" age) arrive as float64.
		age, ok := objMap["age"].(float64)
		require.True(t, ok, "age should be a number")
		require.Greater(t, age, 0.0, "age should be greater than 0")

		city, ok := objMap["city"].(string)
		require.True(t, ok, "city should be a string")
		require.NotEmpty(t, city, "city should not be empty")
	}

	t.Run("simple object", func(t *testing.T) {
		r := vcr.NewRecorder(t)

		languageModel, err := pair.builder(t, r)
		require.NoError(t, err, "failed to build language model")

		response, err := languageModel.GenerateObject(t.Context(), newCall())
		require.NoError(t, err, "failed to generate object")
		require.NotNil(t, response, "response should not be nil")
		require.NotEmpty(t, response.RawText, "raw text should not be empty")
		checkPerson(t, response.Object, response.Usage)
	})

	t.Run("simple object streaming", func(t *testing.T) {
		r := vcr.NewRecorder(t)

		languageModel, err := pair.builder(t, r)
		require.NoError(t, err, "failed to build language model")

		stream, err := languageModel.StreamObject(t.Context(), newCall())
		require.NoError(t, err, "failed to create object stream")
		require.NotNil(t, stream, "stream should not be nil")

		var (
			lastObject   any
			usage        fantasy.Usage
			finishReason fantasy.FinishReason
			objectCount  int
		)
		// Drain the whole stream. Text deltas are deliberately not collected:
		// raw text may be empty when the object is delivered in tool-based mode.
		for part := range stream {
			switch part.Type {
			case fantasy.ObjectStreamPartTypeObject:
				lastObject = part.Object
				objectCount++
			case fantasy.ObjectStreamPartTypeFinish:
				usage = part.Usage
				finishReason = part.FinishReason
			case fantasy.ObjectStreamPartTypeError:
				t.Fatalf("stream error: %v", part.Error)
			}
		}

		require.NotNil(t, lastObject, "should have received at least one object")
		require.Greater(t, objectCount, 0, "should have received object updates")
		require.NotEqual(t, fantasy.FinishReasonUnknown, finishReason, "should have a finish reason")
		checkPerson(t, lastObject, usage)
	})
}
188
// testComplexObject checks nested structured output. It requests a "Book"
// object with a string title, a nested author object, a genres array and an
// integer published_year, through both GenerateObject and StreamObject. Each
// variant runs in its own subtest with its own VCR recorder, and both are
// validated by the same checker.
func testComplexObject(t *testing.T, pair builderPair) {
	// Define a more complex schema with nested objects and arrays
	schema := fantasy.Schema{
		Type: "object",
		Properties: map[string]*fantasy.Schema{
			"title": {
				Type:        "string",
				Description: "The book title",
			},
			"author": {
				Type: "object",
				Properties: map[string]*fantasy.Schema{
					"name": {
						Type:        "string",
						Description: "Author's name",
					},
					"nationality": {
						Type:        "string",
						Description: "Author's nationality",
					},
				},
				Required: []string{"name", "nationality"},
			},
			"genres": {
				Type: "array",
				Items: &fantasy.Schema{
					Type: "string",
				},
				Description: "List of genres",
			},
			"published_year": {
				Type:        "integer",
				Description: "Year the book was published",
			},
		},
		Required: []string{"title", "author", "genres", "published_year"},
	}

	// newCall builds the request used by both subtests.
	newCall := func() fantasy.ObjectCall {
		prompt := fantasy.Prompt{
			fantasy.NewUserMessage("Generate information about 'The Lord of the Rings' book by J.R.R. Tolkien, including genres like fantasy and adventure, and its publication year (1954)."),
		}
		return fantasy.ObjectCall{
			Prompt:            prompt,
			Schema:            schema,
			SchemaName:        "Book",
			SchemaDescription: "A book with title, author, genres, and publication year",
			MaxOutputTokens:   fantasy.Opt(int64(4000)),
			ProviderOptions:   pair.providerOptions,
		}
	}

	// checkBook validates the decoded object and the reported token usage.
	// Raw text is checked by the caller because only GenerateObject is required
	// to return it.
	checkBook := func(t *testing.T, obj any, usage fantasy.Usage) {
		t.Helper()
		require.NotNil(t, obj, "object should not be nil")
		require.Greater(t, usage.TotalTokens, int64(0), "usage should be tracked")

		objMap, ok := obj.(map[string]any)
		require.True(t, ok, "object should be a map")
		require.Contains(t, objMap, "title")
		require.Contains(t, objMap, "author")
		require.Contains(t, objMap, "genres")
		require.Contains(t, objMap, "published_year")

		// Case-insensitive match so minor title variations still pass;
		// require.Contains reports the actual title on failure.
		title, ok := objMap["title"].(string)
		require.True(t, ok, "title should be a string")
		require.Contains(t, strings.ToLower(title), "rings", "title should contain 'rings'")

		author, ok := objMap["author"].(map[string]any)
		require.True(t, ok, "author should be an object")
		require.Contains(t, author, "name")
		require.Contains(t, author, "nationality")

		genres, ok := objMap["genres"].([]any)
		require.True(t, ok, "genres should be an array")
		require.NotEmpty(t, genres, "genres should have at least one item")
		for _, genre := range genres {
			_, isString := genre.(string)
			require.True(t, isString, "each genre should be a string")
		}

		// Assumes the object was decoded from JSON into `any`, where numbers
		// arrive as float64.
		year, ok := objMap["published_year"].(float64)
		require.True(t, ok, "published_year should be a number")
		require.Greater(t, year, 1900.0, "published_year should be after 1900")
	}

	t.Run("complex object", func(t *testing.T) {
		r := vcr.NewRecorder(t)

		languageModel, err := pair.builder(t, r)
		require.NoError(t, err, "failed to build language model")

		response, err := languageModel.GenerateObject(t.Context(), newCall())
		require.NoError(t, err, "failed to generate object")
		require.NotNil(t, response, "response should not be nil")
		require.NotEmpty(t, response.RawText, "raw text should not be empty")
		checkBook(t, response.Object, response.Usage)
	})

	t.Run("complex object streaming", func(t *testing.T) {
		r := vcr.NewRecorder(t)

		languageModel, err := pair.builder(t, r)
		require.NoError(t, err, "failed to build language model")

		stream, err := languageModel.StreamObject(t.Context(), newCall())
		require.NoError(t, err, "failed to create object stream")
		require.NotNil(t, stream, "stream should not be nil")

		var (
			lastObject   any
			usage        fantasy.Usage
			finishReason fantasy.FinishReason
			objectCount  int
		)
		// Drain the whole stream. Text deltas are deliberately not collected:
		// raw text may be empty when the object is delivered in tool-based mode.
		for part := range stream {
			switch part.Type {
			case fantasy.ObjectStreamPartTypeObject:
				lastObject = part.Object
				objectCount++
			case fantasy.ObjectStreamPartTypeFinish:
				usage = part.Usage
				finishReason = part.FinishReason
			case fantasy.ObjectStreamPartTypeError:
				t.Fatalf("stream error: %v", part.Error)
			}
		}

		require.NotNil(t, lastObject, "should have received at least one object")
		require.Greater(t, objectCount, 0, "should have received object updates")
		require.NotEqual(t, fantasy.FinishReasonUnknown, finishReason, "should have a finish reason")
		checkBook(t, lastObject, usage)
	})
}
373
// testObjectWithRepair checks, for each builder pair, that GenerateObject
// succeeds when a RepairText callback is supplied. The callback used here is a
// pass-through that returns the text unchanged. This confirms the option is
// accepted end to end and the resulting object honours the schema. It does not
// exercise an actual repair of malformed output.
func testObjectWithRepair(t *testing.T, pairs []builderPair) {
	for _, pair := range pairs {
		t.Run(pair.name, func(t *testing.T) {
			t.Run("object with repair", func(t *testing.T) {
				r := vcr.NewRecorder(t)

				languageModel, err := pair.builder(t, r)
				require.NoError(t, err, "failed to build language model")

				minVal := 1.0
				schema := fantasy.Schema{
					Type: "object",
					Properties: map[string]*fantasy.Schema{
						"count": {
							Type:        "integer",
							Description: "A count that must be positive",
							Minimum:     &minVal,
						},
					},
					Required: []string{"count"},
				}

				prompt := fantasy.Prompt{
					fantasy.NewUserMessage("Return a count of 5"),
				}

				// Pass-through repair callback. Its parameters are blank so the
				// callback's error argument does not shadow the enclosing err.
				repairFunc := func(_ context.Context, text string, _ error) (string, error) {
					return text, nil
				}

				response, err := languageModel.GenerateObject(t.Context(), fantasy.ObjectCall{
					Prompt:            prompt,
					Schema:            schema,
					SchemaName:        "Count",
					SchemaDescription: "A simple count object",
					MaxOutputTokens:   fantasy.Opt(int64(4000)),
					RepairText:        repairFunc,
					ProviderOptions:   pair.providerOptions,
				})
				require.NoError(t, err, "failed to generate object")
				require.NotNil(t, response, "response should not be nil")
				require.NotNil(t, response.Object, "object should not be nil")

				// Validate the payload against the schema, not just its presence.
				// Assumes JSON decoding into `any`, where numbers arrive as float64.
				objMap, ok := response.Object.(map[string]any)
				require.True(t, ok, "object should be a map")
				count, ok := objMap["count"].(float64)
				require.True(t, ok, "count should be a number")
				require.GreaterOrEqual(t, count, minVal, "count should respect the schema minimum")
			})
		})
	}
}