object_test.go

  1package providertests
  2
  3import (
  4	"context"
  5	"strings"
  6	"testing"
  7
  8	"charm.land/fantasy"
  9	"charm.land/x/vcr"
 10	"github.com/stretchr/testify/require"
 11)
 12
// Object generation tests for providers.
//
// These helpers exercise structured object generation (GenerateObject and
// StreamObject) and can be reused by any provider implementation's test suite.
//
// Usage example:
//
//	func TestMyProviderObjectGeneration(t *testing.T) {
//		var pairs []builderPair
//		for _, m := range myTestModels {
//			pairs = append(pairs, builderPair{m.name, myBuilder(m.model), nil, nil})
//		}
//		testObjectGeneration(t, pairs)
//	}
//
// The tests cover:
//   - Simple object generation (flat schema with basic types)
//   - Complex object generation (nested objects and arrays)
//   - Streaming object generation (progressive updates) for both of the above
//   - Object generation with a custom repair function (a separate entry point,
//     testObjectWithRepair, not run by testObjectGeneration)
 33
 34// testObjectGeneration tests structured object generation for a provider.
 35// It includes both non-streaming (GenerateObject) and streaming (StreamObject) tests.
 36func testObjectGeneration(t *testing.T, pairs []builderPair) {
 37	for _, pair := range pairs {
 38		t.Run(pair.name, func(t *testing.T) {
 39			testSimpleObject(t, pair)
 40			testComplexObject(t, pair)
 41		})
 42	}
 43}
 44
 45func testSimpleObject(t *testing.T, pair builderPair) {
 46	// Define a simple schema for a person object
 47	schema := fantasy.Schema{
 48		Type: "object",
 49		Properties: map[string]*fantasy.Schema{
 50			"name": {
 51				Type:        "string",
 52				Description: "The person's name",
 53			},
 54			"age": {
 55				Type:        "integer",
 56				Description: "The person's age",
 57			},
 58			"city": {
 59				Type:        "string",
 60				Description: "The city where the person lives",
 61			},
 62		},
 63		Required: []string{"name", "age", "city"},
 64	}
 65
 66	checkResult := func(t *testing.T, obj any, rawText string, usage fantasy.Usage) {
 67		require.NotNil(t, obj, "object should not be nil")
 68		require.NotEmpty(t, rawText, "raw text should not be empty")
 69		require.Greater(t, usage.TotalTokens, int64(0), "usage should be tracked")
 70
 71		// Validate structure
 72		objMap, ok := obj.(map[string]any)
 73		require.True(t, ok, "object should be a map")
 74		require.Contains(t, objMap, "name")
 75		require.Contains(t, objMap, "age")
 76		require.Contains(t, objMap, "city")
 77
 78		// Validate types
 79		name, ok := objMap["name"].(string)
 80		require.True(t, ok, "name should be a string")
 81		require.NotEmpty(t, name, "name should not be empty")
 82
 83		// Age could be float64 from JSON unmarshaling
 84		age, ok := objMap["age"].(float64)
 85		require.True(t, ok, "age should be a number")
 86		require.Greater(t, age, 0.0, "age should be greater than 0")
 87
 88		city, ok := objMap["city"].(string)
 89		require.True(t, ok, "city should be a string")
 90		require.NotEmpty(t, city, "city should not be empty")
 91	}
 92
 93	t.Run("simple object", func(t *testing.T) {
 94		r := vcr.NewRecorder(t)
 95
 96		languageModel, err := pair.builder(t, r)
 97		require.NoError(t, err, "failed to build language model")
 98
 99		prompt := fantasy.Prompt{
100			fantasy.NewUserMessage("Generate information about a person named Alice who is 30 years old and lives in Paris."),
101		}
102
103		response, err := languageModel.GenerateObject(t.Context(), fantasy.ObjectCall{
104			Prompt:            prompt,
105			Schema:            schema,
106			SchemaName:        "Person",
107			SchemaDescription: "A person with name, age, and city",
108			MaxOutputTokens:   fantasy.Opt(int64(4000)),
109			ProviderOptions:   pair.providerOptions,
110		})
111		require.NoError(t, err, "failed to generate object")
112		require.NotNil(t, response, "response should not be nil")
113		checkResult(t, response.Object, response.RawText, response.Usage)
114	})
115
116	t.Run("simple object streaming", func(t *testing.T) {
117		r := vcr.NewRecorder(t)
118
119		languageModel, err := pair.builder(t, r)
120		require.NoError(t, err, "failed to build language model")
121
122		prompt := fantasy.Prompt{
123			fantasy.NewUserMessage("Generate information about a person named Alice who is 30 years old and lives in Paris."),
124		}
125
126		stream, err := languageModel.StreamObject(t.Context(), fantasy.ObjectCall{
127			Prompt:            prompt,
128			Schema:            schema,
129			SchemaName:        "Person",
130			SchemaDescription: "A person with name, age, and city",
131			MaxOutputTokens:   fantasy.Opt(int64(4000)),
132			ProviderOptions:   pair.providerOptions,
133		})
134		require.NoError(t, err, "failed to create object stream")
135		require.NotNil(t, stream, "stream should not be nil")
136
137		var lastObject any
138		var rawText string
139		var usage fantasy.Usage
140		var finishReason fantasy.FinishReason
141		objectCount := 0
142
143		for part := range stream {
144			switch part.Type {
145			case fantasy.ObjectStreamPartTypeObject:
146				lastObject = part.Object
147				objectCount++
148			case fantasy.ObjectStreamPartTypeTextDelta:
149				rawText += part.Delta
150			case fantasy.ObjectStreamPartTypeFinish:
151				usage = part.Usage
152				finishReason = part.FinishReason
153			case fantasy.ObjectStreamPartTypeError:
154				t.Fatalf("stream error: %v", part.Error)
155			}
156		}
157
158		require.NotNil(t, lastObject, "should have received at least one object")
159		require.Greater(t, objectCount, 0, "should have received object updates")
160		require.NotEqual(t, fantasy.FinishReasonUnknown, finishReason, "should have a finish reason")
161
162		// Validate object structure without requiring rawText (may be empty in tool-based mode)
163		require.NotNil(t, lastObject, "object should not be nil")
164		require.Greater(t, usage.TotalTokens, int64(0), "usage should be tracked")
165
166		// Validate structure
167		objMap, ok := lastObject.(map[string]any)
168		require.True(t, ok, "object should be a map")
169		require.Contains(t, objMap, "name")
170		require.Contains(t, objMap, "age")
171		require.Contains(t, objMap, "city")
172
173		// Validate types
174		name, ok := objMap["name"].(string)
175		require.True(t, ok, "name should be a string")
176		require.NotEmpty(t, name, "name should not be empty")
177
178		// Age could be float64 from JSON unmarshaling
179		age, ok := objMap["age"].(float64)
180		require.True(t, ok, "age should be a number")
181		require.Greater(t, age, 0.0, "age should be greater than 0")
182
183		city, ok := objMap["city"].(string)
184		require.True(t, ok, "city should be a string")
185		require.NotEmpty(t, city, "city should not be empty")
186	})
187}
188
189func testComplexObject(t *testing.T, pair builderPair) {
190	// Define a more complex schema with nested objects and arrays
191	schema := fantasy.Schema{
192		Type: "object",
193		Properties: map[string]*fantasy.Schema{
194			"title": {
195				Type:        "string",
196				Description: "The book title",
197			},
198			"author": {
199				Type: "object",
200				Properties: map[string]*fantasy.Schema{
201					"name": {
202						Type:        "string",
203						Description: "Author's name",
204					},
205					"nationality": {
206						Type:        "string",
207						Description: "Author's nationality",
208					},
209				},
210				Required: []string{"name", "nationality"},
211			},
212			"genres": {
213				Type: "array",
214				Items: &fantasy.Schema{
215					Type: "string",
216				},
217				Description: "List of genres",
218			},
219			"published_year": {
220				Type:        "integer",
221				Description: "Year the book was published",
222			},
223		},
224		Required: []string{"title", "author", "genres", "published_year"},
225	}
226
227	checkResult := func(t *testing.T, obj any, rawText string, usage fantasy.Usage) {
228		require.NotNil(t, obj, "object should not be nil")
229		require.NotEmpty(t, rawText, "raw text should not be empty")
230		require.Greater(t, usage.TotalTokens, int64(0), "usage should be tracked")
231
232		// Validate structure
233		objMap, ok := obj.(map[string]any)
234		require.True(t, ok, "object should be a map")
235		require.Contains(t, objMap, "title")
236		require.Contains(t, objMap, "author")
237		require.Contains(t, objMap, "genres")
238		require.Contains(t, objMap, "published_year")
239
240		// Validate title
241		title, ok := objMap["title"].(string)
242		require.True(t, ok, "title should be a string")
243		require.True(t, strings.Contains(strings.ToLower(title), "rings"), "title should contain 'rings'")
244
245		// Validate nested author object
246		author, ok := objMap["author"].(map[string]any)
247		require.True(t, ok, "author should be an object")
248		require.Contains(t, author, "name")
249		require.Contains(t, author, "nationality")
250
251		// Validate genres array
252		genres, ok := objMap["genres"].([]any)
253		require.True(t, ok, "genres should be an array")
254		require.Greater(t, len(genres), 0, "genres should have at least one item")
255		for _, genre := range genres {
256			_, ok := genre.(string)
257			require.True(t, ok, "each genre should be a string")
258		}
259
260		// Validate published_year
261		year, ok := objMap["published_year"].(float64)
262		require.True(t, ok, "published_year should be a number")
263		require.Greater(t, year, 1900.0, "published_year should be after 1900")
264	}
265
266	t.Run("complex object", func(t *testing.T) {
267		r := vcr.NewRecorder(t)
268
269		languageModel, err := pair.builder(t, r)
270		require.NoError(t, err, "failed to build language model")
271
272		prompt := fantasy.Prompt{
273			fantasy.NewUserMessage("Generate information about 'The Lord of the Rings' book by J.R.R. Tolkien, including genres like fantasy and adventure, and its publication year (1954)."),
274		}
275
276		response, err := languageModel.GenerateObject(t.Context(), fantasy.ObjectCall{
277			Prompt:            prompt,
278			Schema:            schema,
279			SchemaName:        "Book",
280			SchemaDescription: "A book with title, author, genres, and publication year",
281			MaxOutputTokens:   fantasy.Opt(int64(4000)),
282			ProviderOptions:   pair.providerOptions,
283		})
284		require.NoError(t, err, "failed to generate object")
285		require.NotNil(t, response, "response should not be nil")
286		checkResult(t, response.Object, response.RawText, response.Usage)
287	})
288
289	t.Run("complex object streaming", func(t *testing.T) {
290		r := vcr.NewRecorder(t)
291
292		languageModel, err := pair.builder(t, r)
293		require.NoError(t, err, "failed to build language model")
294
295		prompt := fantasy.Prompt{
296			fantasy.NewUserMessage("Generate information about 'The Lord of the Rings' book by J.R.R. Tolkien, including genres like fantasy and adventure, and its publication year (1954)."),
297		}
298
299		stream, err := languageModel.StreamObject(t.Context(), fantasy.ObjectCall{
300			Prompt:            prompt,
301			Schema:            schema,
302			SchemaName:        "Book",
303			SchemaDescription: "A book with title, author, genres, and publication year",
304			MaxOutputTokens:   fantasy.Opt(int64(4000)),
305			ProviderOptions:   pair.providerOptions,
306		})
307		require.NoError(t, err, "failed to create object stream")
308		require.NotNil(t, stream, "stream should not be nil")
309
310		var lastObject any
311		var rawText string
312		var usage fantasy.Usage
313		var finishReason fantasy.FinishReason
314		objectCount := 0
315
316		for part := range stream {
317			switch part.Type {
318			case fantasy.ObjectStreamPartTypeObject:
319				lastObject = part.Object
320				objectCount++
321			case fantasy.ObjectStreamPartTypeTextDelta:
322				rawText += part.Delta
323			case fantasy.ObjectStreamPartTypeFinish:
324				usage = part.Usage
325				finishReason = part.FinishReason
326			case fantasy.ObjectStreamPartTypeError:
327				t.Fatalf("stream error: %v", part.Error)
328			}
329		}
330
331		require.NotNil(t, lastObject, "should have received at least one object")
332		require.Greater(t, objectCount, 0, "should have received object updates")
333		require.NotEqual(t, fantasy.FinishReasonUnknown, finishReason, "should have a finish reason")
334
335		// Validate object structure without requiring rawText (may be empty in tool-based mode)
336		require.NotNil(t, lastObject, "object should not be nil")
337		require.Greater(t, usage.TotalTokens, int64(0), "usage should be tracked")
338
339		// Validate structure
340		objMap, ok := lastObject.(map[string]any)
341		require.True(t, ok, "object should be a map")
342		require.Contains(t, objMap, "title")
343		require.Contains(t, objMap, "author")
344		require.Contains(t, objMap, "genres")
345		require.Contains(t, objMap, "published_year")
346
347		// Validate title
348		title, ok := objMap["title"].(string)
349		require.True(t, ok, "title should be a string")
350		require.True(t, strings.Contains(strings.ToLower(title), "rings"), "title should contain 'rings'")
351
352		// Validate nested author object
353		author, ok := objMap["author"].(map[string]any)
354		require.True(t, ok, "author should be an object")
355		require.Contains(t, author, "name")
356		require.Contains(t, author, "nationality")
357
358		// Validate genres array
359		genres, ok := objMap["genres"].([]any)
360		require.True(t, ok, "genres should be an array")
361		require.Greater(t, len(genres), 0, "genres should have at least one item")
362		for _, genre := range genres {
363			_, ok := genre.(string)
364			require.True(t, ok, "each genre should be a string")
365		}
366
367		// Validate published_year
368		year, ok := objMap["published_year"].(float64)
369		require.True(t, ok, "published_year should be a number")
370		require.Greater(t, year, 1900.0, "published_year should be after 1900")
371	})
372}
373
374// testObjectWithRepair tests object generation with custom repair functionality.
375func testObjectWithRepair(t *testing.T, pairs []builderPair) {
376	for _, pair := range pairs {
377		t.Run(pair.name, func(t *testing.T) {
378			t.Run("object with repair", func(t *testing.T) {
379				r := vcr.NewRecorder(t)
380
381				languageModel, err := pair.builder(t, r)
382				require.NoError(t, err, "failed to build language model")
383
384				minVal := 1.0
385				schema := fantasy.Schema{
386					Type: "object",
387					Properties: map[string]*fantasy.Schema{
388						"count": {
389							Type:        "integer",
390							Description: "A count that must be positive",
391							Minimum:     &minVal,
392						},
393					},
394					Required: []string{"count"},
395				}
396
397				prompt := fantasy.Prompt{
398					fantasy.NewUserMessage("Return a count of 5"),
399				}
400
401				repairFunc := func(ctx context.Context, text string, err error) (string, error) {
402					// Simple repair: if the JSON is malformed, try to fix it
403					// This is a placeholder - real repair would be more sophisticated
404					return text, nil
405				}
406
407				response, err := languageModel.GenerateObject(t.Context(), fantasy.ObjectCall{
408					Prompt:            prompt,
409					Schema:            schema,
410					SchemaName:        "Count",
411					SchemaDescription: "A simple count object",
412					MaxOutputTokens:   fantasy.Opt(int64(4000)),
413					RepairText:        repairFunc,
414					ProviderOptions:   pair.providerOptions,
415				})
416				require.NoError(t, err, "failed to generate object")
417				require.NotNil(t, response, "response should not be nil")
418				require.NotNil(t, response.Object, "object should not be nil")
419			})
420		})
421	}
422}