1package fantasy
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "iter"
8 "reflect"
9
10 "charm.land/fantasy/schema"
11)
12
// ObjectMode selects the strategy used to obtain structured output from a
// model.
type ObjectMode string

// Recognized ObjectMode values.
const (
	// ObjectModeAuto leaves the choice of strategy to the provider.
	ObjectModeAuto ObjectMode = "auto"
	// ObjectModeJSON requests the provider's native JSON mode, where supported.
	ObjectModeJSON ObjectMode = "json"
	// ObjectModeTool requests the tool-calling strategy.
	ObjectModeTool ObjectMode = "tool"
	// ObjectModeText embeds the schema in the prompt and generates plain
	// text; it is the fallback for models lacking tool or JSON support.
	ObjectModeText ObjectMode = "text"
)
29
30// ObjectCall represents a request to generate a structured object.
31type ObjectCall struct {
32 Prompt Prompt
33 Schema Schema
34 SchemaName string
35 SchemaDescription string
36
37 MaxOutputTokens *int64
38 Temperature *float64
39 TopP *float64
40 TopK *int64
41 PresencePenalty *float64
42 FrequencyPenalty *float64
43
44 // UserAgent overrides the provider-level User-Agent header for this call.
45 UserAgent string `json:"-"`
46
47 ProviderOptions ProviderOptions
48
49 RepairText schema.ObjectRepairFunc
50}
51
52// ObjectResponse represents the response from a structured object generation.
53type ObjectResponse struct {
54 Object any
55 RawText string
56 Usage Usage
57 FinishReason FinishReason
58 Warnings []CallWarning
59 ProviderMetadata ProviderMetadata
60}
61
// ObjectStreamPartType tags an ObjectStreamPart with the kind of payload it
// carries.
type ObjectStreamPartType string

// Recognized ObjectStreamPartType values.
const (
	// ObjectStreamPartTypeObject marks a part carrying a new (possibly
	// partial) object.
	ObjectStreamPartTypeObject ObjectStreamPartType = "object"
	// ObjectStreamPartTypeTextDelta marks a part carrying a chunk of text
	// produced by the model.
	ObjectStreamPartTypeTextDelta ObjectStreamPartType = "text-delta"
	// ObjectStreamPartTypeError marks a part carrying an error.
	ObjectStreamPartTypeError ObjectStreamPartType = "error"
	// ObjectStreamPartTypeFinish marks the part emitted when streaming ends.
	ObjectStreamPartTypeFinish ObjectStreamPartType = "finish"
)
78
79// ObjectStreamPart represents a single chunk in the object stream.
80type ObjectStreamPart struct {
81 Type ObjectStreamPartType
82 Object any
83 Delta string
84 Error error
85 Usage Usage
86 FinishReason FinishReason
87 Warnings []CallWarning
88 ProviderMetadata ProviderMetadata
89}
90
91// ObjectStreamResponse is an iterator over ObjectStreamPart.
92type ObjectStreamResponse = iter.Seq[ObjectStreamPart]
93
94// ObjectResult is a typed result wrapper returned by GenerateObject[T].
95type ObjectResult[T any] struct {
96 Object T
97 RawText string
98 Usage Usage
99 FinishReason FinishReason
100 Warnings []CallWarning
101 ProviderMetadata ProviderMetadata
102}
103
104// StreamObjectResult provides typed access to a streaming object generation result.
105type StreamObjectResult[T any] struct {
106 stream ObjectStreamResponse
107 ctx context.Context
108}
109
110// NewStreamObjectResult creates a typed stream result from an untyped stream.
111func NewStreamObjectResult[T any](ctx context.Context, stream ObjectStreamResponse) *StreamObjectResult[T] {
112 return &StreamObjectResult[T]{
113 stream: stream,
114 ctx: ctx,
115 }
116}
117
118// PartialObjectStream returns an iterator that yields progressively more complete objects.
119// Only emits when the object actually changes (deduplication).
120func (s *StreamObjectResult[T]) PartialObjectStream() iter.Seq[T] {
121 return func(yield func(T) bool) {
122 var lastObject T
123 var hasEmitted bool
124
125 for part := range s.stream {
126 if part.Type == ObjectStreamPartTypeObject && part.Object != nil {
127 var current T
128 if err := unmarshalObject(part.Object, ¤t); err != nil {
129 continue
130 }
131
132 if !hasEmitted || !reflect.DeepEqual(current, lastObject) {
133 if !yield(current) {
134 return
135 }
136 lastObject = current
137 hasEmitted = true
138 }
139 }
140 }
141 }
142}
143
144// TextStream returns an iterator that yields text deltas.
145// Useful if the model generates explanatory text alongside the object.
146func (s *StreamObjectResult[T]) TextStream() iter.Seq[string] {
147 return func(yield func(string) bool) {
148 for part := range s.stream {
149 if part.Type == ObjectStreamPartTypeTextDelta && part.Delta != "" {
150 if !yield(part.Delta) {
151 return
152 }
153 }
154 }
155 }
156}
157
158// FullStream returns an iterator that yields all stream parts including errors and metadata.
159func (s *StreamObjectResult[T]) FullStream() iter.Seq[ObjectStreamPart] {
160 return s.stream
161}
162
163// Object waits for the stream to complete and returns the final object.
164// Returns an error if streaming fails or no valid object was generated.
165func (s *StreamObjectResult[T]) Object() (*ObjectResult[T], error) {
166 var finalObject T
167 var usage Usage
168 var finishReason FinishReason
169 var warnings []CallWarning
170 var providerMetadata ProviderMetadata
171 var rawText string
172 var lastError error
173 hasObject := false
174
175 for part := range s.stream {
176 switch part.Type {
177 case ObjectStreamPartTypeObject:
178 if part.Object != nil {
179 if err := unmarshalObject(part.Object, &finalObject); err == nil {
180 hasObject = true
181 if jsonBytes, err := json.Marshal(part.Object); err == nil {
182 rawText = string(jsonBytes)
183 }
184 }
185 }
186
187 case ObjectStreamPartTypeError:
188 lastError = part.Error
189
190 case ObjectStreamPartTypeFinish:
191 usage = part.Usage
192 finishReason = part.FinishReason
193 if len(part.Warnings) > 0 {
194 warnings = part.Warnings
195 }
196 if len(part.ProviderMetadata) > 0 {
197 providerMetadata = part.ProviderMetadata
198 }
199 }
200 }
201
202 if lastError != nil {
203 return nil, lastError
204 }
205
206 if !hasObject {
207 return nil, &NoObjectGeneratedError{
208 RawText: rawText,
209 ParseError: fmt.Errorf("no valid object generated in stream"),
210 Usage: usage,
211 FinishReason: finishReason,
212 }
213 }
214
215 return &ObjectResult[T]{
216 Object: finalObject,
217 RawText: rawText,
218 Usage: usage,
219 FinishReason: finishReason,
220 Warnings: warnings,
221 ProviderMetadata: providerMetadata,
222 }, nil
223}
224
// unmarshalObject converts obj into target by round-tripping through JSON.
// obj is encoded with json.Marshal, and the resulting bytes are decoded into
// target with json.Unmarshal, which requires a non-nil pointer. A failure in
// either step is returned wrapped with %w.
func unmarshalObject(obj any, target any) error {
	encoded, marshalErr := json.Marshal(obj)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal object: %w", marshalErr)
	}
	if decodeErr := json.Unmarshal(encoded, target); decodeErr != nil {
		return fmt.Errorf("failed to unmarshal into target type: %w", decodeErr)
	}
	return nil
}