object.go

  1// Package object provides utilities for generating structured objects with automatic schema generation.
  2// It simplifies working with typed structured outputs by handling schema reflection and unmarshaling.
  3package object
  4
  5import (
  6	"context"
  7	"encoding/json"
  8	"fmt"
  9	"reflect"
 10
 11	"charm.land/fantasy"
 12	"charm.land/fantasy/schema"
 13)
 14
 15// Generate generates a structured object that matches the given type T.
 16// The schema is automatically generated from T using reflection.
 17//
 18// Example:
 19//
 20//	type Recipe struct {
 21//	    Name        string   `json:"name"`
 22//	    Ingredients []string `json:"ingredients"`
 23//	}
 24//
 25//	result, err := object.Generate[Recipe](ctx, model, fantasy.ObjectCall{
 26//	    Prompt: fantasy.Prompt{fantasy.NewUserMessage("Generate a lasagna recipe")},
 27//	})
 28func Generate[T any](
 29	ctx context.Context,
 30	model fantasy.LanguageModel,
 31	opts fantasy.ObjectCall,
 32) (*fantasy.ObjectResult[T], error) {
 33	var zero T
 34	s := schema.Generate(reflect.TypeOf(zero))
 35	opts.Schema = s
 36
 37	resp, err := model.GenerateObject(ctx, opts)
 38	if err != nil {
 39		return nil, err
 40	}
 41
 42	var result T
 43	if err := unmarshal(resp.Object, &result); err != nil {
 44		return nil, fmt.Errorf("failed to unmarshal to %T: %w", result, err)
 45	}
 46
 47	return &fantasy.ObjectResult[T]{
 48		Object:           result,
 49		RawText:          resp.RawText,
 50		Usage:            resp.Usage,
 51		FinishReason:     resp.FinishReason,
 52		Warnings:         resp.Warnings,
 53		ProviderMetadata: resp.ProviderMetadata,
 54	}, nil
 55}
 56
 57// Stream streams a structured object that matches the given type T.
 58// Returns a StreamObjectResult[T] with progressive updates and deduplication.
 59//
 60// Example:
 61//
 62//	stream, err := object.Stream[Recipe](ctx, model, fantasy.ObjectCall{
 63//	    Prompt: fantasy.Prompt{fantasy.NewUserMessage("Generate a lasagna recipe")},
 64//	})
 65//
 66//	for partial := range stream.PartialObjectStream() {
 67//	    fmt.Printf("Progress: %s\n", partial.Name)
 68//	}
 69//
 70//	result, err := stream.Object()  // Wait for final result
 71func Stream[T any](
 72	ctx context.Context,
 73	model fantasy.LanguageModel,
 74	opts fantasy.ObjectCall,
 75) (*fantasy.StreamObjectResult[T], error) {
 76	var zero T
 77	s := schema.Generate(reflect.TypeOf(zero))
 78	opts.Schema = s
 79
 80	stream, err := model.StreamObject(ctx, opts)
 81	if err != nil {
 82		return nil, err
 83	}
 84
 85	return fantasy.NewStreamObjectResult[T](ctx, stream), nil
 86}
 87
 88// GenerateWithTool is a helper for providers without native JSON mode.
 89// It converts the schema to a tool definition, forces the model to call it,
 90// and extracts the tool's input as the structured output.
 91func GenerateWithTool(
 92	ctx context.Context,
 93	model fantasy.LanguageModel,
 94	call fantasy.ObjectCall,
 95) (*fantasy.ObjectResponse, error) {
 96	toolName := call.SchemaName
 97	if toolName == "" {
 98		toolName = "generate_object"
 99	}
100
101	toolDescription := call.SchemaDescription
102	if toolDescription == "" {
103		toolDescription = "Generate a structured object matching the schema"
104	}
105
106	tool := fantasy.FunctionTool{
107		Name:        toolName,
108		Description: toolDescription,
109		InputSchema: schema.ToMap(call.Schema),
110	}
111
112	toolChoice := fantasy.SpecificToolChoice(tool.Name)
113	resp, err := model.Generate(ctx, fantasy.Call{
114		Prompt:           call.Prompt,
115		Tools:            []fantasy.Tool{tool},
116		ToolChoice:       &toolChoice,
117		MaxOutputTokens:  call.MaxOutputTokens,
118		Temperature:      call.Temperature,
119		TopP:             call.TopP,
120		TopK:             call.TopK,
121		PresencePenalty:  call.PresencePenalty,
122		FrequencyPenalty: call.FrequencyPenalty,
123		UserAgent:        call.UserAgent,
124		ProviderOptions:  call.ProviderOptions,
125	})
126	if err != nil {
127		return nil, fmt.Errorf("tool-based generation failed: %w", err)
128	}
129
130	toolCalls := resp.Content.ToolCalls()
131	if len(toolCalls) == 0 {
132		return nil, &fantasy.NoObjectGeneratedError{
133			RawText:      resp.Content.Text(),
134			ParseError:   fmt.Errorf("no tool call generated"),
135			Usage:        resp.Usage,
136			FinishReason: resp.FinishReason,
137		}
138	}
139
140	toolCall := toolCalls[0]
141
142	var obj any
143	if call.RepairText != nil {
144		obj, err = schema.ParseAndValidateWithRepair(ctx, toolCall.Input, call.Schema, call.RepairText)
145	} else {
146		obj, err = schema.ParseAndValidate(toolCall.Input, call.Schema)
147	}
148
149	if err != nil {
150		if nogErr, ok := err.(*fantasy.NoObjectGeneratedError); ok {
151			nogErr.Usage = resp.Usage
152			nogErr.FinishReason = resp.FinishReason
153		}
154		return nil, err
155	}
156
157	return &fantasy.ObjectResponse{
158		Object:           obj,
159		RawText:          toolCall.Input,
160		Usage:            resp.Usage,
161		FinishReason:     resp.FinishReason,
162		Warnings:         resp.Warnings,
163		ProviderMetadata: resp.ProviderMetadata,
164	}, nil
165}
166
167// GenerateWithText is a helper for providers without tool or JSON mode support.
168// It adds the schema to the system prompt and parses the text response as JSON.
169// This is a fallback for older models or simple providers.
170func GenerateWithText(
171	ctx context.Context,
172	model fantasy.LanguageModel,
173	call fantasy.ObjectCall,
174) (*fantasy.ObjectResponse, error) {
175	jsonSchemaBytes, err := json.Marshal(call.Schema)
176	if err != nil {
177		return nil, fmt.Errorf("failed to marshal schema: %w", err)
178	}
179
180	schemaInstruction := fmt.Sprintf(
181		"You must respond with valid JSON that matches this schema: %s\n"+
182			"Respond ONLY with the JSON object, no additional text or explanation.",
183		string(jsonSchemaBytes),
184	)
185
186	enhancedPrompt := make(fantasy.Prompt, 0, len(call.Prompt)+1)
187
188	hasSystem := false
189	for _, msg := range call.Prompt {
190		if msg.Role == fantasy.MessageRoleSystem {
191			hasSystem = true
192			existingText := ""
193			if len(msg.Content) > 0 {
194				if textPart, ok := msg.Content[0].(fantasy.TextPart); ok {
195					existingText = textPart.Text
196				}
197			}
198			enhancedPrompt = append(enhancedPrompt, fantasy.NewSystemMessage(existingText+"\n\n"+schemaInstruction))
199		} else {
200			enhancedPrompt = append(enhancedPrompt, msg)
201		}
202	}
203
204	if !hasSystem {
205		enhancedPrompt = append(fantasy.Prompt{fantasy.NewSystemMessage(schemaInstruction)}, call.Prompt...)
206	}
207
208	resp, err := model.Generate(ctx, fantasy.Call{
209		Prompt:           enhancedPrompt,
210		MaxOutputTokens:  call.MaxOutputTokens,
211		Temperature:      call.Temperature,
212		TopP:             call.TopP,
213		TopK:             call.TopK,
214		PresencePenalty:  call.PresencePenalty,
215		FrequencyPenalty: call.FrequencyPenalty,
216		UserAgent:        call.UserAgent,
217		ProviderOptions:  call.ProviderOptions,
218	})
219	if err != nil {
220		return nil, fmt.Errorf("text-based generation failed: %w", err)
221	}
222
223	textContent := resp.Content.Text()
224	if textContent == "" {
225		return nil, &fantasy.NoObjectGeneratedError{
226			RawText:      "",
227			ParseError:   fmt.Errorf("no text content in response"),
228			Usage:        resp.Usage,
229			FinishReason: resp.FinishReason,
230		}
231	}
232
233	var obj any
234	if call.RepairText != nil {
235		obj, err = schema.ParseAndValidateWithRepair(ctx, textContent, call.Schema, call.RepairText)
236	} else {
237		obj, err = schema.ParseAndValidate(textContent, call.Schema)
238	}
239
240	if err != nil {
241		if nogErr, ok := err.(*schema.ParseError); ok {
242			return nil, &fantasy.NoObjectGeneratedError{
243				RawText:         nogErr.RawText,
244				ParseError:      nogErr.ParseError,
245				ValidationError: nogErr.ValidationError,
246				Usage:           resp.Usage,
247				FinishReason:    resp.FinishReason,
248			}
249		}
250		return nil, err
251	}
252
253	return &fantasy.ObjectResponse{
254		Object:           obj,
255		RawText:          textContent,
256		Usage:            resp.Usage,
257		FinishReason:     resp.FinishReason,
258		Warnings:         resp.Warnings,
259		ProviderMetadata: resp.ProviderMetadata,
260	}, nil
261}
262
263// StreamWithTool is a helper for providers without native JSON streaming.
264// It uses streaming tool calls to extract and parse the structured output progressively.
265func StreamWithTool(
266	ctx context.Context,
267	model fantasy.LanguageModel,
268	call fantasy.ObjectCall,
269) (fantasy.ObjectStreamResponse, error) {
270	// Create a tool from the schema
271	toolName := call.SchemaName
272	if toolName == "" {
273		toolName = "generate_object"
274	}
275
276	toolDescription := call.SchemaDescription
277	if toolDescription == "" {
278		toolDescription = "Generate a structured object matching the schema"
279	}
280
281	tool := fantasy.FunctionTool{
282		Name:        toolName,
283		Description: toolDescription,
284		InputSchema: schema.ToMap(call.Schema),
285	}
286
287	// Make a streaming Generate call with forced tool choice
288	toolChoice := fantasy.SpecificToolChoice(tool.Name)
289	stream, err := model.Stream(ctx, fantasy.Call{
290		Prompt:           call.Prompt,
291		Tools:            []fantasy.Tool{tool},
292		ToolChoice:       &toolChoice,
293		MaxOutputTokens:  call.MaxOutputTokens,
294		Temperature:      call.Temperature,
295		TopP:             call.TopP,
296		TopK:             call.TopK,
297		PresencePenalty:  call.PresencePenalty,
298		FrequencyPenalty: call.FrequencyPenalty,
299		UserAgent:        call.UserAgent,
300		ProviderOptions:  call.ProviderOptions,
301	})
302	if err != nil {
303		return nil, fmt.Errorf("tool-based streaming failed: %w", err)
304	}
305
306	// Convert the text stream to object stream parts
307	return func(yield func(fantasy.ObjectStreamPart) bool) {
308		var accumulated string
309		var lastParsedObject any
310		var usage fantasy.Usage
311		var finishReason fantasy.FinishReason
312		var warnings []fantasy.CallWarning
313		var providerMetadata fantasy.ProviderMetadata
314		var streamErr error
315
316		for part := range stream {
317			switch part.Type {
318			case fantasy.StreamPartTypeTextDelta:
319				accumulated += part.Delta
320
321				obj, state, parseErr := schema.ParsePartialJSON(accumulated)
322
323				if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
324					if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
325						if !reflect.DeepEqual(obj, lastParsedObject) {
326							if !yield(fantasy.ObjectStreamPart{
327								Type:   fantasy.ObjectStreamPartTypeObject,
328								Object: obj,
329							}) {
330								return
331							}
332							lastParsedObject = obj
333						}
334					}
335				}
336
337				if state == schema.ParseStateFailed && call.RepairText != nil {
338					repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
339					if repairErr == nil {
340						obj2, state2, _ := schema.ParsePartialJSON(repairedText)
341						if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
342							schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
343							if !reflect.DeepEqual(obj2, lastParsedObject) {
344								if !yield(fantasy.ObjectStreamPart{
345									Type:   fantasy.ObjectStreamPartTypeObject,
346									Object: obj2,
347								}) {
348									return
349								}
350								lastParsedObject = obj2
351							}
352						}
353					}
354				}
355
356			case fantasy.StreamPartTypeToolInputDelta:
357				accumulated += part.Delta
358
359				obj, state, parseErr := schema.ParsePartialJSON(accumulated)
360				if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
361					if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
362						if !reflect.DeepEqual(obj, lastParsedObject) {
363							if !yield(fantasy.ObjectStreamPart{
364								Type:   fantasy.ObjectStreamPartTypeObject,
365								Object: obj,
366							}) {
367								return
368							}
369							lastParsedObject = obj
370						}
371					}
372				}
373
374				if state == schema.ParseStateFailed && call.RepairText != nil {
375					repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
376					if repairErr == nil {
377						obj2, state2, _ := schema.ParsePartialJSON(repairedText)
378						if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
379							schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
380							if !reflect.DeepEqual(obj2, lastParsedObject) {
381								if !yield(fantasy.ObjectStreamPart{
382									Type:   fantasy.ObjectStreamPartTypeObject,
383									Object: obj2,
384								}) {
385									return
386								}
387								lastParsedObject = obj2
388							}
389						}
390					}
391				}
392
393			case fantasy.StreamPartTypeToolCall:
394				toolInput := part.ToolCallInput
395
396				var obj any
397				var err error
398				if call.RepairText != nil {
399					obj, err = schema.ParseAndValidateWithRepair(ctx, toolInput, call.Schema, call.RepairText)
400				} else {
401					obj, err = schema.ParseAndValidate(toolInput, call.Schema)
402				}
403
404				if err == nil {
405					if !reflect.DeepEqual(obj, lastParsedObject) {
406						if !yield(fantasy.ObjectStreamPart{
407							Type:   fantasy.ObjectStreamPartTypeObject,
408							Object: obj,
409						}) {
410							return
411						}
412						lastParsedObject = obj
413					}
414				}
415
416			case fantasy.StreamPartTypeError:
417				streamErr = part.Error
418				if !yield(fantasy.ObjectStreamPart{
419					Type:  fantasy.ObjectStreamPartTypeError,
420					Error: part.Error,
421				}) {
422					return
423				}
424
425			case fantasy.StreamPartTypeFinish:
426				usage = part.Usage
427				finishReason = part.FinishReason
428
429			case fantasy.StreamPartTypeWarnings:
430				warnings = part.Warnings
431			}
432
433			if len(part.ProviderMetadata) > 0 {
434				providerMetadata = part.ProviderMetadata
435			}
436		}
437
438		if streamErr == nil && lastParsedObject != nil {
439			yield(fantasy.ObjectStreamPart{
440				Type:             fantasy.ObjectStreamPartTypeFinish,
441				Usage:            usage,
442				FinishReason:     finishReason,
443				Warnings:         warnings,
444				ProviderMetadata: providerMetadata,
445			})
446		} else if streamErr == nil && lastParsedObject == nil {
447			yield(fantasy.ObjectStreamPart{
448				Type: fantasy.ObjectStreamPartTypeError,
449				Error: &fantasy.NoObjectGeneratedError{
450					RawText:      accumulated,
451					ParseError:   fmt.Errorf("no valid object generated in stream"),
452					Usage:        usage,
453					FinishReason: finishReason,
454				},
455			})
456		}
457	}, nil
458}
459
460// StreamWithText is a helper for providers without tool or JSON streaming support.
461// It adds the schema to the system prompt and parses the streamed text as JSON progressively.
462func StreamWithText(
463	ctx context.Context,
464	model fantasy.LanguageModel,
465	call fantasy.ObjectCall,
466) (fantasy.ObjectStreamResponse, error) {
467	jsonSchemaMap := schema.ToMap(call.Schema)
468	jsonSchemaBytes, err := json.Marshal(jsonSchemaMap)
469	if err != nil {
470		return nil, fmt.Errorf("failed to marshal schema: %w", err)
471	}
472
473	schemaInstruction := fmt.Sprintf(
474		"You must respond with valid JSON that matches this schema: %s\n"+
475			"Respond ONLY with the JSON object, no additional text or explanation.",
476		string(jsonSchemaBytes),
477	)
478
479	enhancedPrompt := make(fantasy.Prompt, 0, len(call.Prompt)+1)
480
481	hasSystem := false
482	for _, msg := range call.Prompt {
483		if msg.Role == fantasy.MessageRoleSystem {
484			hasSystem = true
485			existingText := ""
486			if len(msg.Content) > 0 {
487				if textPart, ok := msg.Content[0].(fantasy.TextPart); ok {
488					existingText = textPart.Text
489				}
490			}
491			enhancedPrompt = append(enhancedPrompt, fantasy.NewSystemMessage(existingText+"\n\n"+schemaInstruction))
492		} else {
493			enhancedPrompt = append(enhancedPrompt, msg)
494		}
495	}
496
497	if !hasSystem {
498		enhancedPrompt = append(fantasy.Prompt{fantasy.NewSystemMessage(schemaInstruction)}, call.Prompt...)
499	}
500
501	stream, err := model.Stream(ctx, fantasy.Call{
502		Prompt:           enhancedPrompt,
503		MaxOutputTokens:  call.MaxOutputTokens,
504		Temperature:      call.Temperature,
505		TopP:             call.TopP,
506		TopK:             call.TopK,
507		PresencePenalty:  call.PresencePenalty,
508		FrequencyPenalty: call.FrequencyPenalty,
509		UserAgent:        call.UserAgent,
510		ProviderOptions:  call.ProviderOptions,
511	})
512	if err != nil {
513		return nil, fmt.Errorf("text-based streaming failed: %w", err)
514	}
515
516	return func(yield func(fantasy.ObjectStreamPart) bool) {
517		var accumulated string
518		var lastParsedObject any
519		var usage fantasy.Usage
520		var finishReason fantasy.FinishReason
521		var warnings []fantasy.CallWarning
522		var providerMetadata fantasy.ProviderMetadata
523		var streamErr error
524
525		for part := range stream {
526			switch part.Type {
527			case fantasy.StreamPartTypeTextDelta:
528				accumulated += part.Delta
529
530				obj, state, parseErr := schema.ParsePartialJSON(accumulated)
531
532				if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
533					if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
534						if !reflect.DeepEqual(obj, lastParsedObject) {
535							if !yield(fantasy.ObjectStreamPart{
536								Type:   fantasy.ObjectStreamPartTypeObject,
537								Object: obj,
538							}) {
539								return
540							}
541							lastParsedObject = obj
542						}
543					}
544				}
545
546				if state == schema.ParseStateFailed && call.RepairText != nil {
547					repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
548					if repairErr == nil {
549						obj2, state2, _ := schema.ParsePartialJSON(repairedText)
550						if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
551							schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
552							if !reflect.DeepEqual(obj2, lastParsedObject) {
553								if !yield(fantasy.ObjectStreamPart{
554									Type:   fantasy.ObjectStreamPartTypeObject,
555									Object: obj2,
556								}) {
557									return
558								}
559								lastParsedObject = obj2
560							}
561						}
562					}
563				}
564
565			case fantasy.StreamPartTypeError:
566				streamErr = part.Error
567				if !yield(fantasy.ObjectStreamPart{
568					Type:  fantasy.ObjectStreamPartTypeError,
569					Error: part.Error,
570				}) {
571					return
572				}
573
574			case fantasy.StreamPartTypeFinish:
575				usage = part.Usage
576				finishReason = part.FinishReason
577
578			case fantasy.StreamPartTypeWarnings:
579				warnings = part.Warnings
580			}
581
582			if len(part.ProviderMetadata) > 0 {
583				providerMetadata = part.ProviderMetadata
584			}
585		}
586
587		if streamErr == nil && lastParsedObject != nil {
588			yield(fantasy.ObjectStreamPart{
589				Type:             fantasy.ObjectStreamPartTypeFinish,
590				Usage:            usage,
591				FinishReason:     finishReason,
592				Warnings:         warnings,
593				ProviderMetadata: providerMetadata,
594			})
595		} else if streamErr == nil && lastParsedObject == nil {
596			yield(fantasy.ObjectStreamPart{
597				Type: fantasy.ObjectStreamPartTypeError,
598				Error: &fantasy.NoObjectGeneratedError{
599					RawText:      accumulated,
600					ParseError:   fmt.Errorf("no valid object generated in stream"),
601					Usage:        usage,
602					FinishReason: finishReason,
603				},
604			})
605		}
606	}, nil
607}
608
// unmarshal converts obj into target by encoding obj as JSON and decoding
// that JSON into target. target must be something json.Unmarshal accepts
// (typically a non-nil pointer). Failures of either step are wrapped with a
// message naming the step.
func unmarshal(obj any, target any) error {
	encoded, marshalErr := json.Marshal(obj)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal object: %w", marshalErr)
	}

	decodeErr := json.Unmarshal(encoded, target)
	if decodeErr == nil {
		return nil
	}
	return fmt.Errorf("failed to unmarshal into target type: %w", decodeErr)
}