object.go

  1package fantasy
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"iter"
  8	"reflect"
  9
 10	"charm.land/fantasy/schema"
 11)
 12
// ObjectMode specifies how structured output should be generated.
// It is a string-backed type; the values declared for it in this package
// are the ObjectMode* constants.
type ObjectMode string
 15
// ObjectMode values. Each constant's string form is the lowercase mode name.
const (
	// ObjectModeAuto lets the provider choose the best approach.
	ObjectModeAuto ObjectMode = "auto"

	// ObjectModeJSON forces the use of native JSON mode (if supported).
	ObjectModeJSON ObjectMode = "json"

	// ObjectModeTool forces the use of the tool-based approach.
	ObjectModeTool ObjectMode = "tool"

	// ObjectModeText uses text generation with the schema in the prompt
	// (fallback for models without tool/JSON support).
	ObjectModeText ObjectMode = "text"
)
 29
// ObjectCall represents a request to generate a structured object.
type ObjectCall struct {
	// Prompt is the request input; Schema, SchemaName and SchemaDescription
	// describe the object to generate.
	Prompt            Prompt
	Schema            Schema
	SchemaName        string
	SchemaDescription string

	// Optional generation settings. Each is a pointer, so nil is
	// distinguishable from an explicit zero value.
	// NOTE(review): presumably nil means "use the provider default" — confirm
	// against the provider implementations.
	MaxOutputTokens  *int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	PresencePenalty  *float64
	FrequencyPenalty *float64

	// ProviderOptions holds provider-specific settings; the type is declared
	// elsewhere in the package.
	ProviderOptions ProviderOptions

	// RepairText is a schema.ObjectRepairFunc hook.
	// NOTE(review): presumably used to repair model output that fails to
	// parse or validate — confirm in package schema.
	RepairText schema.ObjectRepairFunc
}
 48
// ObjectResponse represents the response from a structured object generation.
// Object is untyped (any); ObjectResult[T] carries the same set of fields
// with a typed Object.
type ObjectResponse struct {
	// Object is the generated value; RawText is its textual form
	// (presumably the raw model output it was parsed from — TODO confirm).
	Object           any
	RawText          string
	Usage            Usage
	FinishReason     FinishReason
	Warnings         []CallWarning
	ProviderMetadata ProviderMetadata
}
 58
// ObjectStreamPartType indicates the type of a stream part. It is the
// discriminator stored in ObjectStreamPart.Type.
type ObjectStreamPartType string
 61
// ObjectStreamPartType values.
const (
	// ObjectStreamPartTypeObject is emitted when a new partial object is available.
	ObjectStreamPartTypeObject ObjectStreamPartType = "object"

	// ObjectStreamPartTypeTextDelta is emitted for text deltas (if the model generates text).
	ObjectStreamPartTypeTextDelta ObjectStreamPartType = "text-delta"

	// ObjectStreamPartTypeError is emitted when an error occurs.
	ObjectStreamPartTypeError ObjectStreamPartType = "error"

	// ObjectStreamPartTypeFinish is emitted when streaming completes.
	ObjectStreamPartTypeFinish ObjectStreamPartType = "finish"
)
 75
// ObjectStreamPart represents a single chunk in the object stream.
//
// Type selects which of the remaining fields the consumers in this file read:
//   - ObjectStreamPartTypeObject: Object
//   - ObjectStreamPartTypeTextDelta: Delta
//   - ObjectStreamPartTypeError: Error
//   - ObjectStreamPartTypeFinish: Usage, FinishReason, Warnings, ProviderMetadata
type ObjectStreamPart struct {
	Type             ObjectStreamPartType
	Object           any
	Delta            string
	Error            error
	Usage            Usage
	FinishReason     FinishReason
	Warnings         []CallWarning
	ProviderMetadata ProviderMetadata
}
 87
// ObjectStreamResponse is an iterator over ObjectStreamPart.
// It is a type alias (not a distinct type), so any iter.Seq[ObjectStreamPart]
// can be used where an ObjectStreamResponse is expected.
type ObjectStreamResponse = iter.Seq[ObjectStreamPart]
 90
// ObjectResult is a typed result wrapper returned by GenerateObject[T].
// StreamObjectResult[T].Object also returns it once a stream has been drained.
type ObjectResult[T any] struct {
	// Object is the generated value decoded into T.
	Object           T
	RawText          string
	Usage            Usage
	FinishReason     FinishReason
	Warnings         []CallWarning
	ProviderMetadata ProviderMetadata
}
100
// StreamObjectResult provides typed access to a streaming object generation result.
// T is the type that object parts of the stream are decoded into.
type StreamObjectResult[T any] struct {
	// stream is the untyped source that every accessor method ranges over.
	stream ObjectStreamResponse
	// ctx is the context passed to NewStreamObjectResult.
	// NOTE(review): none of the methods in this part of the file read it.
	ctx    context.Context
}
106
107// NewStreamObjectResult creates a typed stream result from an untyped stream.
108func NewStreamObjectResult[T any](ctx context.Context, stream ObjectStreamResponse) *StreamObjectResult[T] {
109	return &StreamObjectResult[T]{
110		stream: stream,
111		ctx:    ctx,
112	}
113}
114
115// PartialObjectStream returns an iterator that yields progressively more complete objects.
116// Only emits when the object actually changes (deduplication).
117func (s *StreamObjectResult[T]) PartialObjectStream() iter.Seq[T] {
118	return func(yield func(T) bool) {
119		var lastObject T
120		var hasEmitted bool
121
122		for part := range s.stream {
123			if part.Type == ObjectStreamPartTypeObject && part.Object != nil {
124				var current T
125				if err := unmarshalObject(part.Object, &current); err != nil {
126					continue
127				}
128
129				if !hasEmitted || !reflect.DeepEqual(current, lastObject) {
130					if !yield(current) {
131						return
132					}
133					lastObject = current
134					hasEmitted = true
135				}
136			}
137		}
138	}
139}
140
141// TextStream returns an iterator that yields text deltas.
142// Useful if the model generates explanatory text alongside the object.
143func (s *StreamObjectResult[T]) TextStream() iter.Seq[string] {
144	return func(yield func(string) bool) {
145		for part := range s.stream {
146			if part.Type == ObjectStreamPartTypeTextDelta && part.Delta != "" {
147				if !yield(part.Delta) {
148					return
149				}
150			}
151		}
152	}
153}
154
155// FullStream returns an iterator that yields all stream parts including errors and metadata.
156func (s *StreamObjectResult[T]) FullStream() iter.Seq[ObjectStreamPart] {
157	return s.stream
158}
159
160// Object waits for the stream to complete and returns the final object.
161// Returns an error if streaming fails or no valid object was generated.
162func (s *StreamObjectResult[T]) Object() (*ObjectResult[T], error) {
163	var finalObject T
164	var usage Usage
165	var finishReason FinishReason
166	var warnings []CallWarning
167	var providerMetadata ProviderMetadata
168	var rawText string
169	var lastError error
170	hasObject := false
171
172	for part := range s.stream {
173		switch part.Type {
174		case ObjectStreamPartTypeObject:
175			if part.Object != nil {
176				if err := unmarshalObject(part.Object, &finalObject); err == nil {
177					hasObject = true
178					if jsonBytes, err := json.Marshal(part.Object); err == nil {
179						rawText = string(jsonBytes)
180					}
181				}
182			}
183
184		case ObjectStreamPartTypeError:
185			lastError = part.Error
186
187		case ObjectStreamPartTypeFinish:
188			usage = part.Usage
189			finishReason = part.FinishReason
190			if len(part.Warnings) > 0 {
191				warnings = part.Warnings
192			}
193			if len(part.ProviderMetadata) > 0 {
194				providerMetadata = part.ProviderMetadata
195			}
196		}
197	}
198
199	if lastError != nil {
200		return nil, lastError
201	}
202
203	if !hasObject {
204		return nil, &NoObjectGeneratedError{
205			RawText:      rawText,
206			ParseError:   fmt.Errorf("no valid object generated in stream"),
207			Usage:        usage,
208			FinishReason: finishReason,
209		}
210	}
211
212	return &ObjectResult[T]{
213		Object:           finalObject,
214		RawText:          rawText,
215		Usage:            usage,
216		FinishReason:     finishReason,
217		Warnings:         warnings,
218		ProviderMetadata: providerMetadata,
219	}, nil
220}
221
// unmarshalObject converts obj into the value pointed to by target by
// encoding obj as JSON and decoding that JSON into target.
//
// target must be a pointer accepted by json.Unmarshal. Encoding and decoding
// failures are returned wrapped (%w) with a message naming the failing step.
func unmarshalObject(obj any, target any) error {
	encoded, marshalErr := json.Marshal(obj)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal object: %w", marshalErr)
	}

	decodeErr := json.Unmarshal(encoded, target)
	if decodeErr == nil {
		return nil
	}
	return fmt.Errorf("failed to unmarshal into target type: %w", decodeErr)
}