1package fantasy
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "iter"
8 "reflect"
9
10 "charm.land/fantasy/schema"
11)
12
// ObjectMode selects the strategy used to obtain structured output from a
// model. It is a string type, so its zero value is the empty string.
type ObjectMode string

// Recognised ObjectMode values.
const (
	// ObjectModeAuto leaves the choice of approach to the provider.
	ObjectModeAuto ObjectMode = "auto"

	// ObjectModeJSON requests the provider's native JSON mode, where supported.
	ObjectModeJSON ObjectMode = "json"

	// ObjectModeTool requests the tool-based approach.
	ObjectModeTool ObjectMode = "tool"

	// ObjectModeText falls back to plain text generation with the schema
	// placed in the prompt, for models lacking tool or JSON support.
	ObjectModeText ObjectMode = "text"
)
29
30// ObjectCall represents a request to generate a structured object.
31type ObjectCall struct {
32 Prompt Prompt
33 Schema Schema
34 SchemaName string
35 SchemaDescription string
36
37 MaxOutputTokens *int64
38 Temperature *float64
39 TopP *float64
40 TopK *int64
41 PresencePenalty *float64
42 FrequencyPenalty *float64
43
44 // UserAgent overrides the provider-level User-Agent header for this call.
45 UserAgent string `json:"-"`
46 // ModelSegment overrides the provider-level model segment for this call.
47 ModelSegment string `json:"-"`
48
49 ProviderOptions ProviderOptions
50
51 RepairText schema.ObjectRepairFunc
52}
53
54// ObjectResponse represents the response from a structured object generation.
55type ObjectResponse struct {
56 Object any
57 RawText string
58 Usage Usage
59 FinishReason FinishReason
60 Warnings []CallWarning
61 ProviderMetadata ProviderMetadata
62}
63
// ObjectStreamPartType identifies which payload an ObjectStreamPart carries.
type ObjectStreamPartType string

// Kinds of ObjectStreamPart.
const (
	// ObjectStreamPartTypeObject marks a part carrying a newly available
	// (possibly partial) object.
	ObjectStreamPartTypeObject ObjectStreamPartType = "object"

	// ObjectStreamPartTypeTextDelta marks a part carrying a text delta, for
	// models that produce text.
	ObjectStreamPartTypeTextDelta ObjectStreamPartType = "text-delta"

	// ObjectStreamPartTypeError marks a part carrying an error.
	ObjectStreamPartTypeError ObjectStreamPartType = "error"

	// ObjectStreamPartTypeFinish marks the part emitted once streaming completes.
	ObjectStreamPartTypeFinish ObjectStreamPartType = "finish"
)
80
81// ObjectStreamPart represents a single chunk in the object stream.
82type ObjectStreamPart struct {
83 Type ObjectStreamPartType
84 Object any
85 Delta string
86 Error error
87 Usage Usage
88 FinishReason FinishReason
89 Warnings []CallWarning
90 ProviderMetadata ProviderMetadata
91}
92
// ObjectStreamResponse is an iterator over ObjectStreamPart values.
// It is a type alias (note the `=`), not a new named type. It is therefore the
// same type as iter.Seq[ObjectStreamPart]: the two are freely interchangeable,
// and it is consumed with `for part := range stream`.
type ObjectStreamResponse = iter.Seq[ObjectStreamPart]
95
96// ObjectResult is a typed result wrapper returned by GenerateObject[T].
97type ObjectResult[T any] struct {
98 Object T
99 RawText string
100 Usage Usage
101 FinishReason FinishReason
102 Warnings []CallWarning
103 ProviderMetadata ProviderMetadata
104}
105
106// StreamObjectResult provides typed access to a streaming object generation result.
107type StreamObjectResult[T any] struct {
108 stream ObjectStreamResponse
109 ctx context.Context
110}
111
112// NewStreamObjectResult creates a typed stream result from an untyped stream.
113func NewStreamObjectResult[T any](ctx context.Context, stream ObjectStreamResponse) *StreamObjectResult[T] {
114 return &StreamObjectResult[T]{
115 stream: stream,
116 ctx: ctx,
117 }
118}
119
120// PartialObjectStream returns an iterator that yields progressively more complete objects.
121// Only emits when the object actually changes (deduplication).
122func (s *StreamObjectResult[T]) PartialObjectStream() iter.Seq[T] {
123 return func(yield func(T) bool) {
124 var lastObject T
125 var hasEmitted bool
126
127 for part := range s.stream {
128 if part.Type == ObjectStreamPartTypeObject && part.Object != nil {
129 var current T
130 if err := unmarshalObject(part.Object, ¤t); err != nil {
131 continue
132 }
133
134 if !hasEmitted || !reflect.DeepEqual(current, lastObject) {
135 if !yield(current) {
136 return
137 }
138 lastObject = current
139 hasEmitted = true
140 }
141 }
142 }
143 }
144}
145
146// TextStream returns an iterator that yields text deltas.
147// Useful if the model generates explanatory text alongside the object.
148func (s *StreamObjectResult[T]) TextStream() iter.Seq[string] {
149 return func(yield func(string) bool) {
150 for part := range s.stream {
151 if part.Type == ObjectStreamPartTypeTextDelta && part.Delta != "" {
152 if !yield(part.Delta) {
153 return
154 }
155 }
156 }
157 }
158}
159
160// FullStream returns an iterator that yields all stream parts including errors and metadata.
161func (s *StreamObjectResult[T]) FullStream() iter.Seq[ObjectStreamPart] {
162 return s.stream
163}
164
165// Object waits for the stream to complete and returns the final object.
166// Returns an error if streaming fails or no valid object was generated.
167func (s *StreamObjectResult[T]) Object() (*ObjectResult[T], error) {
168 var finalObject T
169 var usage Usage
170 var finishReason FinishReason
171 var warnings []CallWarning
172 var providerMetadata ProviderMetadata
173 var rawText string
174 var lastError error
175 hasObject := false
176
177 for part := range s.stream {
178 switch part.Type {
179 case ObjectStreamPartTypeObject:
180 if part.Object != nil {
181 if err := unmarshalObject(part.Object, &finalObject); err == nil {
182 hasObject = true
183 if jsonBytes, err := json.Marshal(part.Object); err == nil {
184 rawText = string(jsonBytes)
185 }
186 }
187 }
188
189 case ObjectStreamPartTypeError:
190 lastError = part.Error
191
192 case ObjectStreamPartTypeFinish:
193 usage = part.Usage
194 finishReason = part.FinishReason
195 if len(part.Warnings) > 0 {
196 warnings = part.Warnings
197 }
198 if len(part.ProviderMetadata) > 0 {
199 providerMetadata = part.ProviderMetadata
200 }
201 }
202 }
203
204 if lastError != nil {
205 return nil, lastError
206 }
207
208 if !hasObject {
209 return nil, &NoObjectGeneratedError{
210 RawText: rawText,
211 ParseError: fmt.Errorf("no valid object generated in stream"),
212 Usage: usage,
213 FinishReason: finishReason,
214 }
215 }
216
217 return &ObjectResult[T]{
218 Object: finalObject,
219 RawText: rawText,
220 Usage: usage,
221 FinishReason: finishReason,
222 Warnings: warnings,
223 ProviderMetadata: providerMetadata,
224 }, nil
225}
226
// unmarshalObject converts obj (typically a map[string]any or similar decoded
// JSON value) into target by round-tripping through JSON. target must be a
// pointer, as for json.Unmarshal. Marshal and unmarshal failures are wrapped
// with a message naming the step that failed.
func unmarshalObject(obj any, target any) error {
	encoded, marshalErr := json.Marshal(obj)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal object: %w", marshalErr)
	}

	decodeErr := json.Unmarshal(encoded, target)
	if decodeErr == nil {
		return nil
	}
	return fmt.Errorf("failed to unmarshal into target type: %w", decodeErr)
}