object.go

  1package fantasy
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"iter"
  8	"reflect"
  9
 10	"charm.land/fantasy/schema"
 11)
 12
// ObjectMode specifies how structured output should be generated.
//
// It is a string-backed type; the values declared in this file are
// ObjectModeAuto, ObjectModeJSON, ObjectModeTool and ObjectModeText.
type ObjectMode string
 15
// Supported ObjectMode values.
const (
	// ObjectModeAuto lets the provider choose the best approach.
	ObjectModeAuto ObjectMode = "auto"

	// ObjectModeJSON forces the use of native JSON mode (if supported).
	ObjectModeJSON ObjectMode = "json"

	// ObjectModeTool forces the use of the tool-based approach.
	ObjectModeTool ObjectMode = "tool"

	// ObjectModeText uses text generation with the schema in the prompt
	// (fallback for models without tool/JSON support).
	ObjectModeText ObjectMode = "text"
)
 29
// ObjectCall represents a request to generate a structured object.
//
// NOTE(review): the fields are consumed by code outside this file; the
// per-group notes below are inferred from field names and types and should be
// confirmed against those consumers.
type ObjectCall struct {
	// The prompt, plus the schema (with a name and description) of the
	// object to generate.
	Prompt            Prompt
	Schema            Schema
	SchemaName        string
	SchemaDescription string

	// Generation parameters. All are pointers, so nil can be told apart from
	// an explicit zero value; presumably nil means "not set" (TODO confirm).
	MaxOutputTokens  *int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	PresencePenalty  *float64
	FrequencyPenalty *float64

	// UserAgent overrides the provider-level User-Agent header for this call.
	// The `json:"-"` tag keeps the field out of encoding/json output.
	UserAgent string `json:"-"`

	// ProviderOptions presumably carries provider-specific settings (TODO confirm).
	ProviderOptions ProviderOptions

	// RepairText is a schema.ObjectRepairFunc; presumably a hook for fixing
	// generated text that does not parse or validate (confirm in package schema).
	RepairText schema.ObjectRepairFunc
}
 51
// ObjectResponse represents the response from a structured object generation.
//
// Object is untyped (any). ObjectResult[T], declared in this file, has the
// same set of fields with Object typed as T.
type ObjectResponse struct {
	Object           any
	RawText          string
	Usage            Usage
	FinishReason     FinishReason
	Warnings         []CallWarning
	ProviderMetadata ProviderMetadata
}
 61
// ObjectStreamPartType indicates the type of stream part.
//
// It is the discriminator stored in ObjectStreamPart.Type; consumers switch
// on it to decide which of the part's other fields to read.
type ObjectStreamPartType string
 64
// Supported ObjectStreamPartType values.
const (
	// ObjectStreamPartTypeObject is emitted when a new partial object is available.
	ObjectStreamPartTypeObject ObjectStreamPartType = "object"

	// ObjectStreamPartTypeTextDelta is emitted for text deltas (if the model generates text).
	ObjectStreamPartTypeTextDelta ObjectStreamPartType = "text-delta"

	// ObjectStreamPartTypeError is emitted when an error occurs.
	ObjectStreamPartTypeError ObjectStreamPartType = "error"

	// ObjectStreamPartTypeFinish is emitted when streaming completes.
	ObjectStreamPartTypeFinish ObjectStreamPartType = "finish"
)
 78
// ObjectStreamPart represents a single chunk in the object stream.
//
// Type selects which of the other fields are meaningful. As read by the
// StreamObjectResult methods in this file:
//   - object parts carry Object;
//   - text-delta parts carry Delta;
//   - error parts carry Error;
//   - finish parts carry Usage, FinishReason, Warnings and ProviderMetadata.
type ObjectStreamPart struct {
	Type             ObjectStreamPartType
	Object           any
	Delta            string
	Error            error
	Usage            Usage
	FinishReason     FinishReason
	Warnings         []CallWarning
	ProviderMetadata ProviderMetadata
}
 90
// ObjectStreamResponse is an iterator over ObjectStreamPart.
//
// It is an alias for iter.Seq[ObjectStreamPart], so any
// func(yield func(ObjectStreamPart) bool) can be used where one is expected,
// and consumers can range over it directly.
type ObjectStreamResponse = iter.Seq[ObjectStreamPart]
 93
// ObjectResult is a typed result wrapper returned by GenerateObject[T].
// It is also the value returned by StreamObjectResult[T].Object.
//
// Object holds the decoded value of type T; the remaining fields have the same
// names and types as those of ObjectResponse.
type ObjectResult[T any] struct {
	Object           T
	RawText          string
	Usage            Usage
	FinishReason     FinishReason
	Warnings         []CallWarning
	ProviderMetadata ProviderMetadata
}
103
// StreamObjectResult provides typed access to a streaming object generation result.
// T is the type that object parts are decoded into.
//
// Every accessor (PartialObjectStream, TextStream, FullStream, Object) ranges
// over the same underlying stream.
// NOTE(review): whether that stream can be iterated more than once depends on
// its producer, which is not visible here; confirm before combining accessors.
type StreamObjectResult[T any] struct {
	stream ObjectStreamResponse // untyped source of stream parts
	ctx    context.Context      // stored by the constructor; not read by any method visible in this file
}
109
110// NewStreamObjectResult creates a typed stream result from an untyped stream.
111func NewStreamObjectResult[T any](ctx context.Context, stream ObjectStreamResponse) *StreamObjectResult[T] {
112	return &StreamObjectResult[T]{
113		stream: stream,
114		ctx:    ctx,
115	}
116}
117
118// PartialObjectStream returns an iterator that yields progressively more complete objects.
119// Only emits when the object actually changes (deduplication).
120func (s *StreamObjectResult[T]) PartialObjectStream() iter.Seq[T] {
121	return func(yield func(T) bool) {
122		var lastObject T
123		var hasEmitted bool
124
125		for part := range s.stream {
126			if part.Type == ObjectStreamPartTypeObject && part.Object != nil {
127				var current T
128				if err := unmarshalObject(part.Object, &current); err != nil {
129					continue
130				}
131
132				if !hasEmitted || !reflect.DeepEqual(current, lastObject) {
133					if !yield(current) {
134						return
135					}
136					lastObject = current
137					hasEmitted = true
138				}
139			}
140		}
141	}
142}
143
144// TextStream returns an iterator that yields text deltas.
145// Useful if the model generates explanatory text alongside the object.
146func (s *StreamObjectResult[T]) TextStream() iter.Seq[string] {
147	return func(yield func(string) bool) {
148		for part := range s.stream {
149			if part.Type == ObjectStreamPartTypeTextDelta && part.Delta != "" {
150				if !yield(part.Delta) {
151					return
152				}
153			}
154		}
155	}
156}
157
158// FullStream returns an iterator that yields all stream parts including errors and metadata.
159func (s *StreamObjectResult[T]) FullStream() iter.Seq[ObjectStreamPart] {
160	return s.stream
161}
162
163// Object waits for the stream to complete and returns the final object.
164// Returns an error if streaming fails or no valid object was generated.
165func (s *StreamObjectResult[T]) Object() (*ObjectResult[T], error) {
166	var finalObject T
167	var usage Usage
168	var finishReason FinishReason
169	var warnings []CallWarning
170	var providerMetadata ProviderMetadata
171	var rawText string
172	var lastError error
173	hasObject := false
174
175	for part := range s.stream {
176		switch part.Type {
177		case ObjectStreamPartTypeObject:
178			if part.Object != nil {
179				if err := unmarshalObject(part.Object, &finalObject); err == nil {
180					hasObject = true
181					if jsonBytes, err := json.Marshal(part.Object); err == nil {
182						rawText = string(jsonBytes)
183					}
184				}
185			}
186
187		case ObjectStreamPartTypeError:
188			lastError = part.Error
189
190		case ObjectStreamPartTypeFinish:
191			usage = part.Usage
192			finishReason = part.FinishReason
193			if len(part.Warnings) > 0 {
194				warnings = part.Warnings
195			}
196			if len(part.ProviderMetadata) > 0 {
197				providerMetadata = part.ProviderMetadata
198			}
199		}
200	}
201
202	if lastError != nil {
203		return nil, lastError
204	}
205
206	if !hasObject {
207		return nil, &NoObjectGeneratedError{
208			RawText:      rawText,
209			ParseError:   fmt.Errorf("no valid object generated in stream"),
210			Usage:        usage,
211			FinishReason: finishReason,
212		}
213	}
214
215	return &ObjectResult[T]{
216		Object:           finalObject,
217		RawText:          rawText,
218		Usage:            usage,
219		FinishReason:     finishReason,
220		Warnings:         warnings,
221		ProviderMetadata: providerMetadata,
222	}, nil
223}
224
// unmarshalObject converts obj into the value pointed to by target by
// round-tripping it through encoding/json: obj is marshaled to JSON and the
// bytes are unmarshaled into target.
//
// target is passed straight to json.Unmarshal and therefore must be a non-nil
// pointer. The returned error wraps the underlying encoding/json error and
// says which of the two steps failed.
func unmarshalObject(obj any, target any) error {
	encoded, marshalErr := json.Marshal(obj)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal object: %w", marshalErr)
	}

	decodeErr := json.Unmarshal(encoded, target)
	if decodeErr == nil {
		return nil
	}
	return fmt.Errorf("failed to unmarshal into target type: %w", decodeErr)
}