1package providertests
2
3import (
4 "context"
5 "strings"
6 "testing"
7
8 "charm.land/fantasy"
9 "github.com/stretchr/testify/require"
10)
11
12// Object generation tests for providers.
13//
14// These test functions can be used to test structured object generation
15// (GenerateObject and StreamObject) for any provider implementation.
16//
17// Usage example:
18//
19// func TestMyProviderObjectGeneration(t *testing.T) {
20// var pairs []builderPair
21// for _, m := range myTestModels {
22// pairs = append(pairs, builderPair{m.name, myBuilder(m.model), nil, nil})
23// }
24// testObjectGeneration(t, pairs)
25// }
26//
// The tests cover:
// - Simple object generation (flat schema with basic types)
// - Complex object generation (nested objects and arrays)
// - Streaming object generation (progressive updates) for both schemas above
// - Object generation with a custom repair function (run separately via
//   testObjectWithRepair; it is not part of testObjectGeneration)
32
33// testObjectGeneration tests structured object generation for a provider.
34// It includes both non-streaming (GenerateObject) and streaming (StreamObject) tests.
35func testObjectGeneration(t *testing.T, pairs []builderPair) {
36 for _, pair := range pairs {
37 t.Run(pair.name, func(t *testing.T) {
38 testSimpleObject(t, pair)
39 testComplexObject(t, pair)
40 })
41 }
42}
43
44func testSimpleObject(t *testing.T, pair builderPair) {
45 // Define a simple schema for a person object
46 schema := fantasy.Schema{
47 Type: "object",
48 Properties: map[string]*fantasy.Schema{
49 "name": {
50 Type: "string",
51 Description: "The person's name",
52 },
53 "age": {
54 Type: "integer",
55 Description: "The person's age",
56 },
57 "city": {
58 Type: "string",
59 Description: "The city where the person lives",
60 },
61 },
62 Required: []string{"name", "age", "city"},
63 }
64
65 checkResult := func(t *testing.T, obj any, rawText string, usage fantasy.Usage) {
66 require.NotNil(t, obj, "object should not be nil")
67 require.NotEmpty(t, rawText, "raw text should not be empty")
68 require.Greater(t, usage.TotalTokens, int64(0), "usage should be tracked")
69
70 // Validate structure
71 objMap, ok := obj.(map[string]any)
72 require.True(t, ok, "object should be a map")
73 require.Contains(t, objMap, "name")
74 require.Contains(t, objMap, "age")
75 require.Contains(t, objMap, "city")
76
77 // Validate types
78 name, ok := objMap["name"].(string)
79 require.True(t, ok, "name should be a string")
80 require.NotEmpty(t, name, "name should not be empty")
81
82 // Age could be float64 from JSON unmarshaling
83 age, ok := objMap["age"].(float64)
84 require.True(t, ok, "age should be a number")
85 require.Greater(t, age, 0.0, "age should be greater than 0")
86
87 city, ok := objMap["city"].(string)
88 require.True(t, ok, "city should be a string")
89 require.NotEmpty(t, city, "city should not be empty")
90 }
91
92 t.Run("simple object", func(t *testing.T) {
93 r := newRecorder(t)
94
95 languageModel, err := pair.builder(t, r)
96 require.NoError(t, err, "failed to build language model")
97
98 prompt := fantasy.Prompt{
99 fantasy.NewUserMessage("Generate information about a person named Alice who is 30 years old and lives in Paris."),
100 }
101
102 response, err := languageModel.GenerateObject(t.Context(), fantasy.ObjectCall{
103 Prompt: prompt,
104 Schema: schema,
105 SchemaName: "Person",
106 SchemaDescription: "A person with name, age, and city",
107 MaxOutputTokens: fantasy.Opt(int64(4000)),
108 ProviderOptions: pair.providerOptions,
109 })
110 require.NoError(t, err, "failed to generate object")
111 require.NotNil(t, response, "response should not be nil")
112 checkResult(t, response.Object, response.RawText, response.Usage)
113 })
114
115 t.Run("simple object streaming", func(t *testing.T) {
116 r := newRecorder(t)
117
118 languageModel, err := pair.builder(t, r)
119 require.NoError(t, err, "failed to build language model")
120
121 prompt := fantasy.Prompt{
122 fantasy.NewUserMessage("Generate information about a person named Alice who is 30 years old and lives in Paris."),
123 }
124
125 stream, err := languageModel.StreamObject(t.Context(), fantasy.ObjectCall{
126 Prompt: prompt,
127 Schema: schema,
128 SchemaName: "Person",
129 SchemaDescription: "A person with name, age, and city",
130 MaxOutputTokens: fantasy.Opt(int64(4000)),
131 ProviderOptions: pair.providerOptions,
132 })
133 require.NoError(t, err, "failed to create object stream")
134 require.NotNil(t, stream, "stream should not be nil")
135
136 var lastObject any
137 var rawText string
138 var usage fantasy.Usage
139 var finishReason fantasy.FinishReason
140 objectCount := 0
141
142 for part := range stream {
143 switch part.Type {
144 case fantasy.ObjectStreamPartTypeObject:
145 lastObject = part.Object
146 objectCount++
147 case fantasy.ObjectStreamPartTypeTextDelta:
148 rawText += part.Delta
149 case fantasy.ObjectStreamPartTypeFinish:
150 usage = part.Usage
151 finishReason = part.FinishReason
152 case fantasy.ObjectStreamPartTypeError:
153 t.Fatalf("stream error: %v", part.Error)
154 }
155 }
156
157 require.NotNil(t, lastObject, "should have received at least one object")
158 require.Greater(t, objectCount, 0, "should have received object updates")
159 require.NotEqual(t, fantasy.FinishReasonUnknown, finishReason, "should have a finish reason")
160
161 // Validate object structure without requiring rawText (may be empty in tool-based mode)
162 require.NotNil(t, lastObject, "object should not be nil")
163 require.Greater(t, usage.TotalTokens, int64(0), "usage should be tracked")
164
165 // Validate structure
166 objMap, ok := lastObject.(map[string]any)
167 require.True(t, ok, "object should be a map")
168 require.Contains(t, objMap, "name")
169 require.Contains(t, objMap, "age")
170 require.Contains(t, objMap, "city")
171
172 // Validate types
173 name, ok := objMap["name"].(string)
174 require.True(t, ok, "name should be a string")
175 require.NotEmpty(t, name, "name should not be empty")
176
177 // Age could be float64 from JSON unmarshaling
178 age, ok := objMap["age"].(float64)
179 require.True(t, ok, "age should be a number")
180 require.Greater(t, age, 0.0, "age should be greater than 0")
181
182 city, ok := objMap["city"].(string)
183 require.True(t, ok, "city should be a string")
184 require.NotEmpty(t, city, "city should not be empty")
185 })
186}
187
188func testComplexObject(t *testing.T, pair builderPair) {
189 // Define a more complex schema with nested objects and arrays
190 schema := fantasy.Schema{
191 Type: "object",
192 Properties: map[string]*fantasy.Schema{
193 "title": {
194 Type: "string",
195 Description: "The book title",
196 },
197 "author": {
198 Type: "object",
199 Properties: map[string]*fantasy.Schema{
200 "name": {
201 Type: "string",
202 Description: "Author's name",
203 },
204 "nationality": {
205 Type: "string",
206 Description: "Author's nationality",
207 },
208 },
209 Required: []string{"name", "nationality"},
210 },
211 "genres": {
212 Type: "array",
213 Items: &fantasy.Schema{
214 Type: "string",
215 },
216 Description: "List of genres",
217 },
218 "published_year": {
219 Type: "integer",
220 Description: "Year the book was published",
221 },
222 },
223 Required: []string{"title", "author", "genres", "published_year"},
224 }
225
226 checkResult := func(t *testing.T, obj any, rawText string, usage fantasy.Usage) {
227 require.NotNil(t, obj, "object should not be nil")
228 require.NotEmpty(t, rawText, "raw text should not be empty")
229 require.Greater(t, usage.TotalTokens, int64(0), "usage should be tracked")
230
231 // Validate structure
232 objMap, ok := obj.(map[string]any)
233 require.True(t, ok, "object should be a map")
234 require.Contains(t, objMap, "title")
235 require.Contains(t, objMap, "author")
236 require.Contains(t, objMap, "genres")
237 require.Contains(t, objMap, "published_year")
238
239 // Validate title
240 title, ok := objMap["title"].(string)
241 require.True(t, ok, "title should be a string")
242 require.True(t, strings.Contains(strings.ToLower(title), "rings"), "title should contain 'rings'")
243
244 // Validate nested author object
245 author, ok := objMap["author"].(map[string]any)
246 require.True(t, ok, "author should be an object")
247 require.Contains(t, author, "name")
248 require.Contains(t, author, "nationality")
249
250 // Validate genres array
251 genres, ok := objMap["genres"].([]any)
252 require.True(t, ok, "genres should be an array")
253 require.Greater(t, len(genres), 0, "genres should have at least one item")
254 for _, genre := range genres {
255 _, ok := genre.(string)
256 require.True(t, ok, "each genre should be a string")
257 }
258
259 // Validate published_year
260 year, ok := objMap["published_year"].(float64)
261 require.True(t, ok, "published_year should be a number")
262 require.Greater(t, year, 1900.0, "published_year should be after 1900")
263 }
264
265 t.Run("complex object", func(t *testing.T) {
266 r := newRecorder(t)
267
268 languageModel, err := pair.builder(t, r)
269 require.NoError(t, err, "failed to build language model")
270
271 prompt := fantasy.Prompt{
272 fantasy.NewUserMessage("Generate information about 'The Lord of the Rings' book by J.R.R. Tolkien, including genres like fantasy and adventure, and its publication year (1954)."),
273 }
274
275 response, err := languageModel.GenerateObject(t.Context(), fantasy.ObjectCall{
276 Prompt: prompt,
277 Schema: schema,
278 SchemaName: "Book",
279 SchemaDescription: "A book with title, author, genres, and publication year",
280 MaxOutputTokens: fantasy.Opt(int64(4000)),
281 ProviderOptions: pair.providerOptions,
282 })
283 require.NoError(t, err, "failed to generate object")
284 require.NotNil(t, response, "response should not be nil")
285 checkResult(t, response.Object, response.RawText, response.Usage)
286 })
287
288 t.Run("complex object streaming", func(t *testing.T) {
289 r := newRecorder(t)
290
291 languageModel, err := pair.builder(t, r)
292 require.NoError(t, err, "failed to build language model")
293
294 prompt := fantasy.Prompt{
295 fantasy.NewUserMessage("Generate information about 'The Lord of the Rings' book by J.R.R. Tolkien, including genres like fantasy and adventure, and its publication year (1954)."),
296 }
297
298 stream, err := languageModel.StreamObject(t.Context(), fantasy.ObjectCall{
299 Prompt: prompt,
300 Schema: schema,
301 SchemaName: "Book",
302 SchemaDescription: "A book with title, author, genres, and publication year",
303 MaxOutputTokens: fantasy.Opt(int64(4000)),
304 ProviderOptions: pair.providerOptions,
305 })
306 require.NoError(t, err, "failed to create object stream")
307 require.NotNil(t, stream, "stream should not be nil")
308
309 var lastObject any
310 var rawText string
311 var usage fantasy.Usage
312 var finishReason fantasy.FinishReason
313 objectCount := 0
314
315 for part := range stream {
316 switch part.Type {
317 case fantasy.ObjectStreamPartTypeObject:
318 lastObject = part.Object
319 objectCount++
320 case fantasy.ObjectStreamPartTypeTextDelta:
321 rawText += part.Delta
322 case fantasy.ObjectStreamPartTypeFinish:
323 usage = part.Usage
324 finishReason = part.FinishReason
325 case fantasy.ObjectStreamPartTypeError:
326 t.Fatalf("stream error: %v", part.Error)
327 }
328 }
329
330 require.NotNil(t, lastObject, "should have received at least one object")
331 require.Greater(t, objectCount, 0, "should have received object updates")
332 require.NotEqual(t, fantasy.FinishReasonUnknown, finishReason, "should have a finish reason")
333
334 // Validate object structure without requiring rawText (may be empty in tool-based mode)
335 require.NotNil(t, lastObject, "object should not be nil")
336 require.Greater(t, usage.TotalTokens, int64(0), "usage should be tracked")
337
338 // Validate structure
339 objMap, ok := lastObject.(map[string]any)
340 require.True(t, ok, "object should be a map")
341 require.Contains(t, objMap, "title")
342 require.Contains(t, objMap, "author")
343 require.Contains(t, objMap, "genres")
344 require.Contains(t, objMap, "published_year")
345
346 // Validate title
347 title, ok := objMap["title"].(string)
348 require.True(t, ok, "title should be a string")
349 require.True(t, strings.Contains(strings.ToLower(title), "rings"), "title should contain 'rings'")
350
351 // Validate nested author object
352 author, ok := objMap["author"].(map[string]any)
353 require.True(t, ok, "author should be an object")
354 require.Contains(t, author, "name")
355 require.Contains(t, author, "nationality")
356
357 // Validate genres array
358 genres, ok := objMap["genres"].([]any)
359 require.True(t, ok, "genres should be an array")
360 require.Greater(t, len(genres), 0, "genres should have at least one item")
361 for _, genre := range genres {
362 _, ok := genre.(string)
363 require.True(t, ok, "each genre should be a string")
364 }
365
366 // Validate published_year
367 year, ok := objMap["published_year"].(float64)
368 require.True(t, ok, "published_year should be a number")
369 require.Greater(t, year, 1900.0, "published_year should be after 1900")
370 })
371}
372
373// testObjectWithRepair tests object generation with custom repair functionality.
374func testObjectWithRepair(t *testing.T, pairs []builderPair) {
375 for _, pair := range pairs {
376 t.Run(pair.name, func(t *testing.T) {
377 t.Run("object with repair", func(t *testing.T) {
378 r := newRecorder(t)
379
380 languageModel, err := pair.builder(t, r)
381 require.NoError(t, err, "failed to build language model")
382
383 minVal := 1.0
384 schema := fantasy.Schema{
385 Type: "object",
386 Properties: map[string]*fantasy.Schema{
387 "count": {
388 Type: "integer",
389 Description: "A count that must be positive",
390 Minimum: &minVal,
391 },
392 },
393 Required: []string{"count"},
394 }
395
396 prompt := fantasy.Prompt{
397 fantasy.NewUserMessage("Return a count of 5"),
398 }
399
400 repairFunc := func(ctx context.Context, text string, err error) (string, error) {
401 // Simple repair: if the JSON is malformed, try to fix it
402 // This is a placeholder - real repair would be more sophisticated
403 return text, nil
404 }
405
406 response, err := languageModel.GenerateObject(t.Context(), fantasy.ObjectCall{
407 Prompt: prompt,
408 Schema: schema,
409 SchemaName: "Count",
410 SchemaDescription: "A simple count object",
411 MaxOutputTokens: fantasy.Opt(int64(4000)),
412 RepairText: repairFunc,
413 ProviderOptions: pair.providerOptions,
414 })
415 require.NoError(t, err, "failed to generate object")
416 require.NotNil(t, response, "response should not be nil")
417 require.NotNil(t, response.Object, "object should not be nil")
418 })
419 })
420 }
421}