object_test.go

  1package providertests
  2
  3import (
  4	"context"
  5	"strings"
  6	"testing"
  7
  8	"charm.land/fantasy"
  9	"github.com/stretchr/testify/require"
 10)
 11
 12// Object generation tests for providers.
 13//
 14// These test functions can be used to test structured object generation
 15// (GenerateObject and StreamObject) for any provider implementation.
 16//
 17// Usage example:
 18//
 19//	func TestMyProviderObjectGeneration(t *testing.T) {
 20//		var pairs []builderPair
 21//		for _, m := range myTestModels {
 22//			pairs = append(pairs, builderPair{m.name, myBuilder(m.model), nil, nil})
 23//		}
 24//		testObjectGeneration(t, pairs)
 25//	}
 26//
// The tests cover:
//   - Simple object generation (flat schema with basic types)
//   - Complex object generation (nested objects and arrays)
//   - Streaming object generation (progressive updates) for both schemas
//   - Object generation with a custom repair function (testObjectWithRepair;
//     it is not run by testObjectGeneration and must be called separately)
 32
 33// testObjectGeneration tests structured object generation for a provider.
 34// It includes both non-streaming (GenerateObject) and streaming (StreamObject) tests.
 35func testObjectGeneration(t *testing.T, pairs []builderPair) {
 36	for _, pair := range pairs {
 37		t.Run(pair.name, func(t *testing.T) {
 38			testSimpleObject(t, pair)
 39			testComplexObject(t, pair)
 40		})
 41	}
 42}
 43
 44func testSimpleObject(t *testing.T, pair builderPair) {
 45	// Define a simple schema for a person object
 46	schema := fantasy.Schema{
 47		Type: "object",
 48		Properties: map[string]*fantasy.Schema{
 49			"name": {
 50				Type:        "string",
 51				Description: "The person's name",
 52			},
 53			"age": {
 54				Type:        "integer",
 55				Description: "The person's age",
 56			},
 57			"city": {
 58				Type:        "string",
 59				Description: "The city where the person lives",
 60			},
 61		},
 62		Required: []string{"name", "age", "city"},
 63	}
 64
 65	checkResult := func(t *testing.T, obj any, rawText string, usage fantasy.Usage) {
 66		require.NotNil(t, obj, "object should not be nil")
 67		require.NotEmpty(t, rawText, "raw text should not be empty")
 68		require.Greater(t, usage.TotalTokens, int64(0), "usage should be tracked")
 69
 70		// Validate structure
 71		objMap, ok := obj.(map[string]any)
 72		require.True(t, ok, "object should be a map")
 73		require.Contains(t, objMap, "name")
 74		require.Contains(t, objMap, "age")
 75		require.Contains(t, objMap, "city")
 76
 77		// Validate types
 78		name, ok := objMap["name"].(string)
 79		require.True(t, ok, "name should be a string")
 80		require.NotEmpty(t, name, "name should not be empty")
 81
 82		// Age could be float64 from JSON unmarshaling
 83		age, ok := objMap["age"].(float64)
 84		require.True(t, ok, "age should be a number")
 85		require.Greater(t, age, 0.0, "age should be greater than 0")
 86
 87		city, ok := objMap["city"].(string)
 88		require.True(t, ok, "city should be a string")
 89		require.NotEmpty(t, city, "city should not be empty")
 90	}
 91
 92	t.Run("simple object", func(t *testing.T) {
 93		r := newRecorder(t)
 94
 95		languageModel, err := pair.builder(t, r)
 96		require.NoError(t, err, "failed to build language model")
 97
 98		prompt := fantasy.Prompt{
 99			fantasy.NewUserMessage("Generate information about a person named Alice who is 30 years old and lives in Paris."),
100		}
101
102		response, err := languageModel.GenerateObject(t.Context(), fantasy.ObjectCall{
103			Prompt:            prompt,
104			Schema:            schema,
105			SchemaName:        "Person",
106			SchemaDescription: "A person with name, age, and city",
107			MaxOutputTokens:   fantasy.Opt(int64(4000)),
108			ProviderOptions:   pair.providerOptions,
109		})
110		require.NoError(t, err, "failed to generate object")
111		require.NotNil(t, response, "response should not be nil")
112		checkResult(t, response.Object, response.RawText, response.Usage)
113	})
114
115	t.Run("simple object streaming", func(t *testing.T) {
116		r := newRecorder(t)
117
118		languageModel, err := pair.builder(t, r)
119		require.NoError(t, err, "failed to build language model")
120
121		prompt := fantasy.Prompt{
122			fantasy.NewUserMessage("Generate information about a person named Alice who is 30 years old and lives in Paris."),
123		}
124
125		stream, err := languageModel.StreamObject(t.Context(), fantasy.ObjectCall{
126			Prompt:            prompt,
127			Schema:            schema,
128			SchemaName:        "Person",
129			SchemaDescription: "A person with name, age, and city",
130			MaxOutputTokens:   fantasy.Opt(int64(4000)),
131			ProviderOptions:   pair.providerOptions,
132		})
133		require.NoError(t, err, "failed to create object stream")
134		require.NotNil(t, stream, "stream should not be nil")
135
136		var lastObject any
137		var rawText string
138		var usage fantasy.Usage
139		var finishReason fantasy.FinishReason
140		objectCount := 0
141
142		for part := range stream {
143			switch part.Type {
144			case fantasy.ObjectStreamPartTypeObject:
145				lastObject = part.Object
146				objectCount++
147			case fantasy.ObjectStreamPartTypeTextDelta:
148				rawText += part.Delta
149			case fantasy.ObjectStreamPartTypeFinish:
150				usage = part.Usage
151				finishReason = part.FinishReason
152			case fantasy.ObjectStreamPartTypeError:
153				t.Fatalf("stream error: %v", part.Error)
154			}
155		}
156
157		require.NotNil(t, lastObject, "should have received at least one object")
158		require.Greater(t, objectCount, 0, "should have received object updates")
159		require.NotEqual(t, fantasy.FinishReasonUnknown, finishReason, "should have a finish reason")
160
161		// Validate object structure without requiring rawText (may be empty in tool-based mode)
162		require.NotNil(t, lastObject, "object should not be nil")
163		require.Greater(t, usage.TotalTokens, int64(0), "usage should be tracked")
164
165		// Validate structure
166		objMap, ok := lastObject.(map[string]any)
167		require.True(t, ok, "object should be a map")
168		require.Contains(t, objMap, "name")
169		require.Contains(t, objMap, "age")
170		require.Contains(t, objMap, "city")
171
172		// Validate types
173		name, ok := objMap["name"].(string)
174		require.True(t, ok, "name should be a string")
175		require.NotEmpty(t, name, "name should not be empty")
176
177		// Age could be float64 from JSON unmarshaling
178		age, ok := objMap["age"].(float64)
179		require.True(t, ok, "age should be a number")
180		require.Greater(t, age, 0.0, "age should be greater than 0")
181
182		city, ok := objMap["city"].(string)
183		require.True(t, ok, "city should be a string")
184		require.NotEmpty(t, city, "city should not be empty")
185	})
186}
187
188func testComplexObject(t *testing.T, pair builderPair) {
189	// Define a more complex schema with nested objects and arrays
190	schema := fantasy.Schema{
191		Type: "object",
192		Properties: map[string]*fantasy.Schema{
193			"title": {
194				Type:        "string",
195				Description: "The book title",
196			},
197			"author": {
198				Type: "object",
199				Properties: map[string]*fantasy.Schema{
200					"name": {
201						Type:        "string",
202						Description: "Author's name",
203					},
204					"nationality": {
205						Type:        "string",
206						Description: "Author's nationality",
207					},
208				},
209				Required: []string{"name", "nationality"},
210			},
211			"genres": {
212				Type: "array",
213				Items: &fantasy.Schema{
214					Type: "string",
215				},
216				Description: "List of genres",
217			},
218			"published_year": {
219				Type:        "integer",
220				Description: "Year the book was published",
221			},
222		},
223		Required: []string{"title", "author", "genres", "published_year"},
224	}
225
226	checkResult := func(t *testing.T, obj any, rawText string, usage fantasy.Usage) {
227		require.NotNil(t, obj, "object should not be nil")
228		require.NotEmpty(t, rawText, "raw text should not be empty")
229		require.Greater(t, usage.TotalTokens, int64(0), "usage should be tracked")
230
231		// Validate structure
232		objMap, ok := obj.(map[string]any)
233		require.True(t, ok, "object should be a map")
234		require.Contains(t, objMap, "title")
235		require.Contains(t, objMap, "author")
236		require.Contains(t, objMap, "genres")
237		require.Contains(t, objMap, "published_year")
238
239		// Validate title
240		title, ok := objMap["title"].(string)
241		require.True(t, ok, "title should be a string")
242		require.True(t, strings.Contains(strings.ToLower(title), "rings"), "title should contain 'rings'")
243
244		// Validate nested author object
245		author, ok := objMap["author"].(map[string]any)
246		require.True(t, ok, "author should be an object")
247		require.Contains(t, author, "name")
248		require.Contains(t, author, "nationality")
249
250		// Validate genres array
251		genres, ok := objMap["genres"].([]any)
252		require.True(t, ok, "genres should be an array")
253		require.Greater(t, len(genres), 0, "genres should have at least one item")
254		for _, genre := range genres {
255			_, ok := genre.(string)
256			require.True(t, ok, "each genre should be a string")
257		}
258
259		// Validate published_year
260		year, ok := objMap["published_year"].(float64)
261		require.True(t, ok, "published_year should be a number")
262		require.Greater(t, year, 1900.0, "published_year should be after 1900")
263	}
264
265	t.Run("complex object", func(t *testing.T) {
266		r := newRecorder(t)
267
268		languageModel, err := pair.builder(t, r)
269		require.NoError(t, err, "failed to build language model")
270
271		prompt := fantasy.Prompt{
272			fantasy.NewUserMessage("Generate information about 'The Lord of the Rings' book by J.R.R. Tolkien, including genres like fantasy and adventure, and its publication year (1954)."),
273		}
274
275		response, err := languageModel.GenerateObject(t.Context(), fantasy.ObjectCall{
276			Prompt:            prompt,
277			Schema:            schema,
278			SchemaName:        "Book",
279			SchemaDescription: "A book with title, author, genres, and publication year",
280			MaxOutputTokens:   fantasy.Opt(int64(4000)),
281			ProviderOptions:   pair.providerOptions,
282		})
283		require.NoError(t, err, "failed to generate object")
284		require.NotNil(t, response, "response should not be nil")
285		checkResult(t, response.Object, response.RawText, response.Usage)
286	})
287
288	t.Run("complex object streaming", func(t *testing.T) {
289		r := newRecorder(t)
290
291		languageModel, err := pair.builder(t, r)
292		require.NoError(t, err, "failed to build language model")
293
294		prompt := fantasy.Prompt{
295			fantasy.NewUserMessage("Generate information about 'The Lord of the Rings' book by J.R.R. Tolkien, including genres like fantasy and adventure, and its publication year (1954)."),
296		}
297
298		stream, err := languageModel.StreamObject(t.Context(), fantasy.ObjectCall{
299			Prompt:            prompt,
300			Schema:            schema,
301			SchemaName:        "Book",
302			SchemaDescription: "A book with title, author, genres, and publication year",
303			MaxOutputTokens:   fantasy.Opt(int64(4000)),
304			ProviderOptions:   pair.providerOptions,
305		})
306		require.NoError(t, err, "failed to create object stream")
307		require.NotNil(t, stream, "stream should not be nil")
308
309		var lastObject any
310		var rawText string
311		var usage fantasy.Usage
312		var finishReason fantasy.FinishReason
313		objectCount := 0
314
315		for part := range stream {
316			switch part.Type {
317			case fantasy.ObjectStreamPartTypeObject:
318				lastObject = part.Object
319				objectCount++
320			case fantasy.ObjectStreamPartTypeTextDelta:
321				rawText += part.Delta
322			case fantasy.ObjectStreamPartTypeFinish:
323				usage = part.Usage
324				finishReason = part.FinishReason
325			case fantasy.ObjectStreamPartTypeError:
326				t.Fatalf("stream error: %v", part.Error)
327			}
328		}
329
330		require.NotNil(t, lastObject, "should have received at least one object")
331		require.Greater(t, objectCount, 0, "should have received object updates")
332		require.NotEqual(t, fantasy.FinishReasonUnknown, finishReason, "should have a finish reason")
333
334		// Validate object structure without requiring rawText (may be empty in tool-based mode)
335		require.NotNil(t, lastObject, "object should not be nil")
336		require.Greater(t, usage.TotalTokens, int64(0), "usage should be tracked")
337
338		// Validate structure
339		objMap, ok := lastObject.(map[string]any)
340		require.True(t, ok, "object should be a map")
341		require.Contains(t, objMap, "title")
342		require.Contains(t, objMap, "author")
343		require.Contains(t, objMap, "genres")
344		require.Contains(t, objMap, "published_year")
345
346		// Validate title
347		title, ok := objMap["title"].(string)
348		require.True(t, ok, "title should be a string")
349		require.True(t, strings.Contains(strings.ToLower(title), "rings"), "title should contain 'rings'")
350
351		// Validate nested author object
352		author, ok := objMap["author"].(map[string]any)
353		require.True(t, ok, "author should be an object")
354		require.Contains(t, author, "name")
355		require.Contains(t, author, "nationality")
356
357		// Validate genres array
358		genres, ok := objMap["genres"].([]any)
359		require.True(t, ok, "genres should be an array")
360		require.Greater(t, len(genres), 0, "genres should have at least one item")
361		for _, genre := range genres {
362			_, ok := genre.(string)
363			require.True(t, ok, "each genre should be a string")
364		}
365
366		// Validate published_year
367		year, ok := objMap["published_year"].(float64)
368		require.True(t, ok, "published_year should be a number")
369		require.Greater(t, year, 1900.0, "published_year should be after 1900")
370	})
371}
372
373// testObjectWithRepair tests object generation with custom repair functionality.
374func testObjectWithRepair(t *testing.T, pairs []builderPair) {
375	for _, pair := range pairs {
376		t.Run(pair.name, func(t *testing.T) {
377			t.Run("object with repair", func(t *testing.T) {
378				r := newRecorder(t)
379
380				languageModel, err := pair.builder(t, r)
381				require.NoError(t, err, "failed to build language model")
382
383				minVal := 1.0
384				schema := fantasy.Schema{
385					Type: "object",
386					Properties: map[string]*fantasy.Schema{
387						"count": {
388							Type:        "integer",
389							Description: "A count that must be positive",
390							Minimum:     &minVal,
391						},
392					},
393					Required: []string{"count"},
394				}
395
396				prompt := fantasy.Prompt{
397					fantasy.NewUserMessage("Return a count of 5"),
398				}
399
400				repairFunc := func(ctx context.Context, text string, err error) (string, error) {
401					// Simple repair: if the JSON is malformed, try to fix it
402					// This is a placeholder - real repair would be more sophisticated
403					return text, nil
404				}
405
406				response, err := languageModel.GenerateObject(t.Context(), fantasy.ObjectCall{
407					Prompt:            prompt,
408					Schema:            schema,
409					SchemaName:        "Count",
410					SchemaDescription: "A simple count object",
411					MaxOutputTokens:   fantasy.Opt(int64(4000)),
412					RepairText:        repairFunc,
413					ProviderOptions:   pair.providerOptions,
414				})
415				require.NoError(t, err, "failed to generate object")
416				require.NotNil(t, response, "response should not be nil")
417				require.NotNil(t, response.Object, "object should not be nil")
418			})
419		})
420	}
421}