object.go

  1package fantasy
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"iter"
  8	"reflect"
  9
 10	"charm.land/fantasy/schema"
 11)
 12
 13// ObjectMode specifies how structured output should be generated.
 14type ObjectMode string
 15
 16const (
 17	// ObjectModeAuto lets the provider choose the best approach.
 18	ObjectModeAuto ObjectMode = "auto"
 19
 20	// ObjectModeJSON forces the use of native JSON mode (if supported).
 21	ObjectModeJSON ObjectMode = "json"
 22
 23	// ObjectModeTool forces the use of tool-based approach.
 24	ObjectModeTool ObjectMode = "tool"
 25
 26	// ObjectModeText uses text generation with schema in prompt (fallback for models without tool/JSON support).
 27	ObjectModeText ObjectMode = "text"
 28)
 29
// ObjectCall represents a request to generate a structured object.
type ObjectCall struct {
	// Prompt is the model input. Schema describes the object to produce;
	// SchemaName and SchemaDescription optionally label that schema
	// (NOTE(review): how providers use the label is not visible here).
	Prompt            Prompt
	Schema            Schema
	SchemaName        string
	SchemaDescription string

	// Optional sampling controls. The pointer types let a caller distinguish
	// "not set" (nil) from an explicit zero value; presumably nil means the
	// provider default applies — confirm in provider implementations.
	MaxOutputTokens  *int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	PresencePenalty  *float64
	FrequencyPenalty *float64

	// UserAgent overrides the provider-level User-Agent header for this call.
	// Excluded from JSON encoding.
	UserAgent string `json:"-"`
	// ModelSegment overrides the provider-level model segment for this call.
	// Excluded from JSON encoding.
	ModelSegment string `json:"-"`

	// ProviderOptions carries provider-specific settings; their interpretation
	// is up to the provider (not visible in this file).
	ProviderOptions ProviderOptions

	// RepairText is an optional schema.ObjectRepairFunc hook. NOTE(review):
	// presumably invoked to repair model output that fails to parse against
	// Schema — confirm in the schema package.
	RepairText schema.ObjectRepairFunc
}
 53
// ObjectResponse represents the response from a structured object generation.
//
// Object holds the generated value untyped (as any); ObjectResult[T] offers
// the same fields with Object typed as T. NOTE(review): RawText presumably
// holds the model's raw output text — confirm with the producing provider.
type ObjectResponse struct {
	Object           any
	RawText          string
	Usage            Usage
	FinishReason     FinishReason
	Warnings         []CallWarning
	ProviderMetadata ProviderMetadata
}
 63
 64// ObjectStreamPartType indicates the type of stream part.
 65type ObjectStreamPartType string
 66
 67const (
 68	// ObjectStreamPartTypeObject is emitted when a new partial object is available.
 69	ObjectStreamPartTypeObject ObjectStreamPartType = "object"
 70
 71	// ObjectStreamPartTypeTextDelta is emitted for text deltas (if model generates text).
 72	ObjectStreamPartTypeTextDelta ObjectStreamPartType = "text-delta"
 73
 74	// ObjectStreamPartTypeError is emitted when an error occurs.
 75	ObjectStreamPartTypeError ObjectStreamPartType = "error"
 76
 77	// ObjectStreamPartTypeFinish is emitted when streaming completes.
 78	ObjectStreamPartTypeFinish ObjectStreamPartType = "finish"
 79)
 80
// ObjectStreamPart represents a single chunk in the object stream.
//
// Which fields are meaningful depends on Type. The StreamObjectResult methods
// in this file read:
//   - Object for ObjectStreamPartTypeObject parts,
//   - Delta for ObjectStreamPartTypeTextDelta parts,
//   - Error for ObjectStreamPartTypeError parts,
//   - Usage, FinishReason, Warnings and ProviderMetadata for
//     ObjectStreamPartTypeFinish parts.
type ObjectStreamPart struct {
	Type             ObjectStreamPartType
	Object           any
	Delta            string
	Error            error
	Usage            Usage
	FinishReason     FinishReason
	Warnings         []CallWarning
	ProviderMetadata ProviderMetadata
}
 92
// ObjectStreamResponse is an iterator over ObjectStreamPart.
//
// It is a type alias (not a new type) for iter.Seq[ObjectStreamPart], so any
// iter.Seq of parts can be used wherever an ObjectStreamResponse is expected.
// NOTE(review): StreamObjectResult's accessors each range over the sequence
// independently; confirm that producers support being ranged more than once.
type ObjectStreamResponse = iter.Seq[ObjectStreamPart]
 95
// ObjectResult is a typed result wrapper returned by GenerateObject[T].
//
// It has the same fields as ObjectResponse, except that Object is typed as T
// instead of any. StreamObjectResult[T].Object also returns this type.
type ObjectResult[T any] struct {
	Object           T
	RawText          string
	Usage            Usage
	FinishReason     FinishReason
	Warnings         []CallWarning
	ProviderMetadata ProviderMetadata
}
105
// StreamObjectResult provides typed access to a streaming object generation result.
//
// It does not buffer parts: each accessor method ranges over stream directly.
// NOTE(review): ctx is stored but is not read by any method in this file.
type StreamObjectResult[T any] struct {
	stream ObjectStreamResponse
	ctx    context.Context
}
111
112// NewStreamObjectResult creates a typed stream result from an untyped stream.
113func NewStreamObjectResult[T any](ctx context.Context, stream ObjectStreamResponse) *StreamObjectResult[T] {
114	return &StreamObjectResult[T]{
115		stream: stream,
116		ctx:    ctx,
117	}
118}
119
120// PartialObjectStream returns an iterator that yields progressively more complete objects.
121// Only emits when the object actually changes (deduplication).
122func (s *StreamObjectResult[T]) PartialObjectStream() iter.Seq[T] {
123	return func(yield func(T) bool) {
124		var lastObject T
125		var hasEmitted bool
126
127		for part := range s.stream {
128			if part.Type == ObjectStreamPartTypeObject && part.Object != nil {
129				var current T
130				if err := unmarshalObject(part.Object, &current); err != nil {
131					continue
132				}
133
134				if !hasEmitted || !reflect.DeepEqual(current, lastObject) {
135					if !yield(current) {
136						return
137					}
138					lastObject = current
139					hasEmitted = true
140				}
141			}
142		}
143	}
144}
145
146// TextStream returns an iterator that yields text deltas.
147// Useful if the model generates explanatory text alongside the object.
148func (s *StreamObjectResult[T]) TextStream() iter.Seq[string] {
149	return func(yield func(string) bool) {
150		for part := range s.stream {
151			if part.Type == ObjectStreamPartTypeTextDelta && part.Delta != "" {
152				if !yield(part.Delta) {
153					return
154				}
155			}
156		}
157	}
158}
159
160// FullStream returns an iterator that yields all stream parts including errors and metadata.
161func (s *StreamObjectResult[T]) FullStream() iter.Seq[ObjectStreamPart] {
162	return s.stream
163}
164
165// Object waits for the stream to complete and returns the final object.
166// Returns an error if streaming fails or no valid object was generated.
167func (s *StreamObjectResult[T]) Object() (*ObjectResult[T], error) {
168	var finalObject T
169	var usage Usage
170	var finishReason FinishReason
171	var warnings []CallWarning
172	var providerMetadata ProviderMetadata
173	var rawText string
174	var lastError error
175	hasObject := false
176
177	for part := range s.stream {
178		switch part.Type {
179		case ObjectStreamPartTypeObject:
180			if part.Object != nil {
181				if err := unmarshalObject(part.Object, &finalObject); err == nil {
182					hasObject = true
183					if jsonBytes, err := json.Marshal(part.Object); err == nil {
184						rawText = string(jsonBytes)
185					}
186				}
187			}
188
189		case ObjectStreamPartTypeError:
190			lastError = part.Error
191
192		case ObjectStreamPartTypeFinish:
193			usage = part.Usage
194			finishReason = part.FinishReason
195			if len(part.Warnings) > 0 {
196				warnings = part.Warnings
197			}
198			if len(part.ProviderMetadata) > 0 {
199				providerMetadata = part.ProviderMetadata
200			}
201		}
202	}
203
204	if lastError != nil {
205		return nil, lastError
206	}
207
208	if !hasObject {
209		return nil, &NoObjectGeneratedError{
210			RawText:      rawText,
211			ParseError:   fmt.Errorf("no valid object generated in stream"),
212			Usage:        usage,
213			FinishReason: finishReason,
214		}
215	}
216
217	return &ObjectResult[T]{
218		Object:           finalObject,
219		RawText:          rawText,
220		Usage:            usage,
221		FinishReason:     finishReason,
222		Warnings:         warnings,
223		ProviderMetadata: providerMetadata,
224	}, nil
225}
226
// unmarshalObject converts obj into target by round-tripping it through JSON.
// obj is encoded with json.Marshal, and the bytes are decoded with
// json.Unmarshal into target, which must therefore be a non-nil pointer.
//
// Because json.Unmarshal merges into existing values, callers that want a
// clean result should pass a pointer to a zero value.
func unmarshalObject(obj any, target any) error {
	encoded, marshalErr := json.Marshal(obj)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal object: %w", marshalErr)
	}

	unmarshalErr := json.Unmarshal(encoded, target)
	if unmarshalErr == nil {
		return nil
	}
	return fmt.Errorf("failed to unmarshal into target type: %w", unmarshalErr)
}