object.go

  1// Package object provides utilities for generating structured objects with automatic schema generation.
  2// It simplifies working with typed structured outputs by handling schema reflection and unmarshaling.
  3package object
  4
  5import (
  6	"context"
  7	"encoding/json"
  8	"fmt"
  9	"reflect"
 10
 11	"charm.land/fantasy"
 12	"charm.land/fantasy/schema"
 13)
 14
 15// Generate generates a structured object that matches the given type T.
 16// The schema is automatically generated from T using reflection.
 17//
 18// Example:
 19//
 20//	type Recipe struct {
 21//	    Name        string   `json:"name"`
 22//	    Ingredients []string `json:"ingredients"`
 23//	}
 24//
 25//	result, err := object.Generate[Recipe](ctx, model, fantasy.ObjectCall{
 26//	    Prompt: fantasy.Prompt{fantasy.NewUserMessage("Generate a lasagna recipe")},
 27//	})
 28func Generate[T any](
 29	ctx context.Context,
 30	model fantasy.LanguageModel,
 31	opts fantasy.ObjectCall,
 32) (*fantasy.ObjectResult[T], error) {
 33	var zero T
 34	s := schema.Generate(reflect.TypeOf(zero))
 35	opts.Schema = s
 36
 37	resp, err := model.GenerateObject(ctx, opts)
 38	if err != nil {
 39		return nil, err
 40	}
 41
 42	var result T
 43	if err := unmarshal(resp.Object, &result); err != nil {
 44		return nil, fmt.Errorf("failed to unmarshal to %T: %w", result, err)
 45	}
 46
 47	return &fantasy.ObjectResult[T]{
 48		Object:           result,
 49		RawText:          resp.RawText,
 50		Usage:            resp.Usage,
 51		FinishReason:     resp.FinishReason,
 52		Warnings:         resp.Warnings,
 53		ProviderMetadata: resp.ProviderMetadata,
 54	}, nil
 55}
 56
 57// Stream streams a structured object that matches the given type T.
 58// Returns a StreamObjectResult[T] with progressive updates and deduplication.
 59//
 60// Example:
 61//
 62//	stream, err := object.Stream[Recipe](ctx, model, fantasy.ObjectCall{
 63//	    Prompt: fantasy.Prompt{fantasy.NewUserMessage("Generate a lasagna recipe")},
 64//	})
 65//
 66//	for partial := range stream.PartialObjectStream() {
 67//	    fmt.Printf("Progress: %s\n", partial.Name)
 68//	}
 69//
 70//	result, err := stream.Object()  // Wait for final result
 71func Stream[T any](
 72	ctx context.Context,
 73	model fantasy.LanguageModel,
 74	opts fantasy.ObjectCall,
 75) (*fantasy.StreamObjectResult[T], error) {
 76	var zero T
 77	s := schema.Generate(reflect.TypeOf(zero))
 78	opts.Schema = s
 79
 80	stream, err := model.StreamObject(ctx, opts)
 81	if err != nil {
 82		return nil, err
 83	}
 84
 85	return fantasy.NewStreamObjectResult[T](ctx, stream), nil
 86}
 87
 88// GenerateWithTool is a helper for providers without native JSON mode.
 89// It converts the schema to a tool definition, forces the model to call it,
 90// and extracts the tool's input as the structured output.
 91func GenerateWithTool(
 92	ctx context.Context,
 93	model fantasy.LanguageModel,
 94	call fantasy.ObjectCall,
 95) (*fantasy.ObjectResponse, error) {
 96	toolName := call.SchemaName
 97	if toolName == "" {
 98		toolName = "generate_object"
 99	}
100
101	toolDescription := call.SchemaDescription
102	if toolDescription == "" {
103		toolDescription = "Generate a structured object matching the schema"
104	}
105
106	tool := fantasy.FunctionTool{
107		Name:        toolName,
108		Description: toolDescription,
109		InputSchema: schema.ToMap(call.Schema),
110	}
111
112	toolChoice := fantasy.SpecificToolChoice(tool.Name)
113	resp, err := model.Generate(ctx, fantasy.Call{
114		Prompt:           call.Prompt,
115		Tools:            []fantasy.Tool{tool},
116		ToolChoice:       &toolChoice,
117		MaxOutputTokens:  call.MaxOutputTokens,
118		Temperature:      call.Temperature,
119		TopP:             call.TopP,
120		TopK:             call.TopK,
121		PresencePenalty:  call.PresencePenalty,
122		FrequencyPenalty: call.FrequencyPenalty,
123		ProviderOptions:  call.ProviderOptions,
124	})
125	if err != nil {
126		return nil, fmt.Errorf("tool-based generation failed: %w", err)
127	}
128
129	toolCalls := resp.Content.ToolCalls()
130	if len(toolCalls) == 0 {
131		return nil, &fantasy.NoObjectGeneratedError{
132			RawText:      resp.Content.Text(),
133			ParseError:   fmt.Errorf("no tool call generated"),
134			Usage:        resp.Usage,
135			FinishReason: resp.FinishReason,
136		}
137	}
138
139	toolCall := toolCalls[0]
140
141	var obj any
142	if call.RepairText != nil {
143		obj, err = schema.ParseAndValidateWithRepair(ctx, toolCall.Input, call.Schema, call.RepairText)
144	} else {
145		obj, err = schema.ParseAndValidate(toolCall.Input, call.Schema)
146	}
147
148	if err != nil {
149		if nogErr, ok := err.(*fantasy.NoObjectGeneratedError); ok {
150			nogErr.Usage = resp.Usage
151			nogErr.FinishReason = resp.FinishReason
152		}
153		return nil, err
154	}
155
156	return &fantasy.ObjectResponse{
157		Object:           obj,
158		RawText:          toolCall.Input,
159		Usage:            resp.Usage,
160		FinishReason:     resp.FinishReason,
161		Warnings:         resp.Warnings,
162		ProviderMetadata: resp.ProviderMetadata,
163	}, nil
164}
165
166// GenerateWithText is a helper for providers without tool or JSON mode support.
167// It adds the schema to the system prompt and parses the text response as JSON.
168// This is a fallback for older models or simple providers.
169func GenerateWithText(
170	ctx context.Context,
171	model fantasy.LanguageModel,
172	call fantasy.ObjectCall,
173) (*fantasy.ObjectResponse, error) {
174	jsonSchemaBytes, err := json.Marshal(call.Schema)
175	if err != nil {
176		return nil, fmt.Errorf("failed to marshal schema: %w", err)
177	}
178
179	schemaInstruction := fmt.Sprintf(
180		"You must respond with valid JSON that matches this schema: %s\n"+
181			"Respond ONLY with the JSON object, no additional text or explanation.",
182		string(jsonSchemaBytes),
183	)
184
185	enhancedPrompt := make(fantasy.Prompt, 0, len(call.Prompt)+1)
186
187	hasSystem := false
188	for _, msg := range call.Prompt {
189		if msg.Role == fantasy.MessageRoleSystem {
190			hasSystem = true
191			existingText := ""
192			if len(msg.Content) > 0 {
193				if textPart, ok := msg.Content[0].(fantasy.TextPart); ok {
194					existingText = textPart.Text
195				}
196			}
197			enhancedPrompt = append(enhancedPrompt, fantasy.NewSystemMessage(existingText+"\n\n"+schemaInstruction))
198		} else {
199			enhancedPrompt = append(enhancedPrompt, msg)
200		}
201	}
202
203	if !hasSystem {
204		enhancedPrompt = append(fantasy.Prompt{fantasy.NewSystemMessage(schemaInstruction)}, call.Prompt...)
205	}
206
207	resp, err := model.Generate(ctx, fantasy.Call{
208		Prompt:           enhancedPrompt,
209		MaxOutputTokens:  call.MaxOutputTokens,
210		Temperature:      call.Temperature,
211		TopP:             call.TopP,
212		TopK:             call.TopK,
213		PresencePenalty:  call.PresencePenalty,
214		FrequencyPenalty: call.FrequencyPenalty,
215		ProviderOptions:  call.ProviderOptions,
216	})
217	if err != nil {
218		return nil, fmt.Errorf("text-based generation failed: %w", err)
219	}
220
221	textContent := resp.Content.Text()
222	if textContent == "" {
223		return nil, &fantasy.NoObjectGeneratedError{
224			RawText:      "",
225			ParseError:   fmt.Errorf("no text content in response"),
226			Usage:        resp.Usage,
227			FinishReason: resp.FinishReason,
228		}
229	}
230
231	var obj any
232	if call.RepairText != nil {
233		obj, err = schema.ParseAndValidateWithRepair(ctx, textContent, call.Schema, call.RepairText)
234	} else {
235		obj, err = schema.ParseAndValidate(textContent, call.Schema)
236	}
237
238	if err != nil {
239		if nogErr, ok := err.(*schema.ParseError); ok {
240			return nil, &fantasy.NoObjectGeneratedError{
241				RawText:         nogErr.RawText,
242				ParseError:      nogErr.ParseError,
243				ValidationError: nogErr.ValidationError,
244				Usage:           resp.Usage,
245				FinishReason:    resp.FinishReason,
246			}
247		}
248		return nil, err
249	}
250
251	return &fantasy.ObjectResponse{
252		Object:           obj,
253		RawText:          textContent,
254		Usage:            resp.Usage,
255		FinishReason:     resp.FinishReason,
256		Warnings:         resp.Warnings,
257		ProviderMetadata: resp.ProviderMetadata,
258	}, nil
259}
260
261// StreamWithTool is a helper for providers without native JSON streaming.
262// It uses streaming tool calls to extract and parse the structured output progressively.
263func StreamWithTool(
264	ctx context.Context,
265	model fantasy.LanguageModel,
266	call fantasy.ObjectCall,
267) (fantasy.ObjectStreamResponse, error) {
268	// Create a tool from the schema
269	toolName := call.SchemaName
270	if toolName == "" {
271		toolName = "generate_object"
272	}
273
274	toolDescription := call.SchemaDescription
275	if toolDescription == "" {
276		toolDescription = "Generate a structured object matching the schema"
277	}
278
279	tool := fantasy.FunctionTool{
280		Name:        toolName,
281		Description: toolDescription,
282		InputSchema: schema.ToMap(call.Schema),
283	}
284
285	// Make a streaming Generate call with forced tool choice
286	toolChoice := fantasy.SpecificToolChoice(tool.Name)
287	stream, err := model.Stream(ctx, fantasy.Call{
288		Prompt:           call.Prompt,
289		Tools:            []fantasy.Tool{tool},
290		ToolChoice:       &toolChoice,
291		MaxOutputTokens:  call.MaxOutputTokens,
292		Temperature:      call.Temperature,
293		TopP:             call.TopP,
294		TopK:             call.TopK,
295		PresencePenalty:  call.PresencePenalty,
296		FrequencyPenalty: call.FrequencyPenalty,
297		ProviderOptions:  call.ProviderOptions,
298	})
299	if err != nil {
300		return nil, fmt.Errorf("tool-based streaming failed: %w", err)
301	}
302
303	// Convert the text stream to object stream parts
304	return func(yield func(fantasy.ObjectStreamPart) bool) {
305		var accumulated string
306		var lastParsedObject any
307		var usage fantasy.Usage
308		var finishReason fantasy.FinishReason
309		var warnings []fantasy.CallWarning
310		var providerMetadata fantasy.ProviderMetadata
311		var streamErr error
312
313		for part := range stream {
314			switch part.Type {
315			case fantasy.StreamPartTypeTextDelta:
316				accumulated += part.Delta
317
318				obj, state, parseErr := schema.ParsePartialJSON(accumulated)
319
320				if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
321					if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
322						if !reflect.DeepEqual(obj, lastParsedObject) {
323							if !yield(fantasy.ObjectStreamPart{
324								Type:   fantasy.ObjectStreamPartTypeObject,
325								Object: obj,
326							}) {
327								return
328							}
329							lastParsedObject = obj
330						}
331					}
332				}
333
334				if state == schema.ParseStateFailed && call.RepairText != nil {
335					repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
336					if repairErr == nil {
337						obj2, state2, _ := schema.ParsePartialJSON(repairedText)
338						if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
339							schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
340							if !reflect.DeepEqual(obj2, lastParsedObject) {
341								if !yield(fantasy.ObjectStreamPart{
342									Type:   fantasy.ObjectStreamPartTypeObject,
343									Object: obj2,
344								}) {
345									return
346								}
347								lastParsedObject = obj2
348							}
349						}
350					}
351				}
352
353			case fantasy.StreamPartTypeToolInputDelta:
354				accumulated += part.Delta
355
356				obj, state, parseErr := schema.ParsePartialJSON(accumulated)
357				if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
358					if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
359						if !reflect.DeepEqual(obj, lastParsedObject) {
360							if !yield(fantasy.ObjectStreamPart{
361								Type:   fantasy.ObjectStreamPartTypeObject,
362								Object: obj,
363							}) {
364								return
365							}
366							lastParsedObject = obj
367						}
368					}
369				}
370
371				if state == schema.ParseStateFailed && call.RepairText != nil {
372					repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
373					if repairErr == nil {
374						obj2, state2, _ := schema.ParsePartialJSON(repairedText)
375						if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
376							schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
377							if !reflect.DeepEqual(obj2, lastParsedObject) {
378								if !yield(fantasy.ObjectStreamPart{
379									Type:   fantasy.ObjectStreamPartTypeObject,
380									Object: obj2,
381								}) {
382									return
383								}
384								lastParsedObject = obj2
385							}
386						}
387					}
388				}
389
390			case fantasy.StreamPartTypeToolCall:
391				toolInput := part.ToolCallInput
392
393				var obj any
394				var err error
395				if call.RepairText != nil {
396					obj, err = schema.ParseAndValidateWithRepair(ctx, toolInput, call.Schema, call.RepairText)
397				} else {
398					obj, err = schema.ParseAndValidate(toolInput, call.Schema)
399				}
400
401				if err == nil {
402					if !reflect.DeepEqual(obj, lastParsedObject) {
403						if !yield(fantasy.ObjectStreamPart{
404							Type:   fantasy.ObjectStreamPartTypeObject,
405							Object: obj,
406						}) {
407							return
408						}
409						lastParsedObject = obj
410					}
411				}
412
413			case fantasy.StreamPartTypeError:
414				streamErr = part.Error
415				if !yield(fantasy.ObjectStreamPart{
416					Type:  fantasy.ObjectStreamPartTypeError,
417					Error: part.Error,
418				}) {
419					return
420				}
421
422			case fantasy.StreamPartTypeFinish:
423				usage = part.Usage
424				finishReason = part.FinishReason
425
426			case fantasy.StreamPartTypeWarnings:
427				warnings = part.Warnings
428			}
429
430			if len(part.ProviderMetadata) > 0 {
431				providerMetadata = part.ProviderMetadata
432			}
433		}
434
435		if streamErr == nil && lastParsedObject != nil {
436			yield(fantasy.ObjectStreamPart{
437				Type:             fantasy.ObjectStreamPartTypeFinish,
438				Usage:            usage,
439				FinishReason:     finishReason,
440				Warnings:         warnings,
441				ProviderMetadata: providerMetadata,
442			})
443		} else if streamErr == nil && lastParsedObject == nil {
444			yield(fantasy.ObjectStreamPart{
445				Type: fantasy.ObjectStreamPartTypeError,
446				Error: &fantasy.NoObjectGeneratedError{
447					RawText:      accumulated,
448					ParseError:   fmt.Errorf("no valid object generated in stream"),
449					Usage:        usage,
450					FinishReason: finishReason,
451				},
452			})
453		}
454	}, nil
455}
456
457// StreamWithText is a helper for providers without tool or JSON streaming support.
458// It adds the schema to the system prompt and parses the streamed text as JSON progressively.
459func StreamWithText(
460	ctx context.Context,
461	model fantasy.LanguageModel,
462	call fantasy.ObjectCall,
463) (fantasy.ObjectStreamResponse, error) {
464	jsonSchemaMap := schema.ToMap(call.Schema)
465	jsonSchemaBytes, err := json.Marshal(jsonSchemaMap)
466	if err != nil {
467		return nil, fmt.Errorf("failed to marshal schema: %w", err)
468	}
469
470	schemaInstruction := fmt.Sprintf(
471		"You must respond with valid JSON that matches this schema: %s\n"+
472			"Respond ONLY with the JSON object, no additional text or explanation.",
473		string(jsonSchemaBytes),
474	)
475
476	enhancedPrompt := make(fantasy.Prompt, 0, len(call.Prompt)+1)
477
478	hasSystem := false
479	for _, msg := range call.Prompt {
480		if msg.Role == fantasy.MessageRoleSystem {
481			hasSystem = true
482			existingText := ""
483			if len(msg.Content) > 0 {
484				if textPart, ok := msg.Content[0].(fantasy.TextPart); ok {
485					existingText = textPart.Text
486				}
487			}
488			enhancedPrompt = append(enhancedPrompt, fantasy.NewSystemMessage(existingText+"\n\n"+schemaInstruction))
489		} else {
490			enhancedPrompt = append(enhancedPrompt, msg)
491		}
492	}
493
494	if !hasSystem {
495		enhancedPrompt = append(fantasy.Prompt{fantasy.NewSystemMessage(schemaInstruction)}, call.Prompt...)
496	}
497
498	stream, err := model.Stream(ctx, fantasy.Call{
499		Prompt:           enhancedPrompt,
500		MaxOutputTokens:  call.MaxOutputTokens,
501		Temperature:      call.Temperature,
502		TopP:             call.TopP,
503		TopK:             call.TopK,
504		PresencePenalty:  call.PresencePenalty,
505		FrequencyPenalty: call.FrequencyPenalty,
506		ProviderOptions:  call.ProviderOptions,
507	})
508	if err != nil {
509		return nil, fmt.Errorf("text-based streaming failed: %w", err)
510	}
511
512	return func(yield func(fantasy.ObjectStreamPart) bool) {
513		var accumulated string
514		var lastParsedObject any
515		var usage fantasy.Usage
516		var finishReason fantasy.FinishReason
517		var warnings []fantasy.CallWarning
518		var providerMetadata fantasy.ProviderMetadata
519		var streamErr error
520
521		for part := range stream {
522			switch part.Type {
523			case fantasy.StreamPartTypeTextDelta:
524				accumulated += part.Delta
525
526				obj, state, parseErr := schema.ParsePartialJSON(accumulated)
527
528				if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
529					if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
530						if !reflect.DeepEqual(obj, lastParsedObject) {
531							if !yield(fantasy.ObjectStreamPart{
532								Type:   fantasy.ObjectStreamPartTypeObject,
533								Object: obj,
534							}) {
535								return
536							}
537							lastParsedObject = obj
538						}
539					}
540				}
541
542				if state == schema.ParseStateFailed && call.RepairText != nil {
543					repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
544					if repairErr == nil {
545						obj2, state2, _ := schema.ParsePartialJSON(repairedText)
546						if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
547							schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
548							if !reflect.DeepEqual(obj2, lastParsedObject) {
549								if !yield(fantasy.ObjectStreamPart{
550									Type:   fantasy.ObjectStreamPartTypeObject,
551									Object: obj2,
552								}) {
553									return
554								}
555								lastParsedObject = obj2
556							}
557						}
558					}
559				}
560
561			case fantasy.StreamPartTypeError:
562				streamErr = part.Error
563				if !yield(fantasy.ObjectStreamPart{
564					Type:  fantasy.ObjectStreamPartTypeError,
565					Error: part.Error,
566				}) {
567					return
568				}
569
570			case fantasy.StreamPartTypeFinish:
571				usage = part.Usage
572				finishReason = part.FinishReason
573
574			case fantasy.StreamPartTypeWarnings:
575				warnings = part.Warnings
576			}
577
578			if len(part.ProviderMetadata) > 0 {
579				providerMetadata = part.ProviderMetadata
580			}
581		}
582
583		if streamErr == nil && lastParsedObject != nil {
584			yield(fantasy.ObjectStreamPart{
585				Type:             fantasy.ObjectStreamPartTypeFinish,
586				Usage:            usage,
587				FinishReason:     finishReason,
588				Warnings:         warnings,
589				ProviderMetadata: providerMetadata,
590			})
591		} else if streamErr == nil && lastParsedObject == nil {
592			yield(fantasy.ObjectStreamPart{
593				Type: fantasy.ObjectStreamPartTypeError,
594				Error: &fantasy.NoObjectGeneratedError{
595					RawText:      accumulated,
596					ParseError:   fmt.Errorf("no valid object generated in stream"),
597					Usage:        usage,
598					FinishReason: finishReason,
599				},
600			})
601		}
602	}, nil
603}
604
// unmarshal converts obj into the value pointed to by target through a JSON
// round-trip: obj is encoded with json.Marshal and the bytes are decoded into
// target with json.Unmarshal. Failures of either step are wrapped and returned.
func unmarshal(obj any, target any) error {
	encoded, marshalErr := json.Marshal(obj)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal object: %w", marshalErr)
	}

	decodeErr := json.Unmarshal(encoded, target)
	if decodeErr == nil {
		return nil
	}
	return fmt.Errorf("failed to unmarshal into target type: %w", decodeErr)
}