1package fantasy
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "iter"
8 "reflect"
9
10 "charm.land/fantasy/schema"
11)
12
// ObjectMode selects the strategy used to obtain structured output from a
// model.
type ObjectMode string

// Supported ObjectMode values.
const (
	// ObjectModeAuto leaves the choice of strategy to the provider.
	ObjectModeAuto = ObjectMode("auto")

	// ObjectModeJSON requests the provider's native JSON mode, where supported.
	ObjectModeJSON = ObjectMode("json")

	// ObjectModeTool requests the tool-calling based approach.
	ObjectModeTool = ObjectMode("tool")

	// ObjectModeText puts the schema in the prompt and relies on plain text
	// generation; a fallback for models without tool or JSON support.
	ObjectModeText = ObjectMode("text")
)
29
// ObjectCall represents a request to generate a structured object.
//
// NOTE(review): the field notes below are inferred from names and types only;
// confirm the exact semantics in the providers that consume ObjectCall.
type ObjectCall struct {
	// Prompt is the prompt to send to the model.
	Prompt Prompt
	// Schema describes the object the model is asked to produce.
	Schema Schema
	// SchemaName and SchemaDescription are optional labels for Schema,
	// presumably forwarded to providers that support them.
	SchemaName        string
	SchemaDescription string

	// Optional generation settings. A nil pointer presumably means "not set",
	// so the provider default applies — confirm.
	MaxOutputTokens  *int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	PresencePenalty  *float64
	FrequencyPenalty *float64

	// ProviderOptions carries provider-specific settings.
	ProviderOptions ProviderOptions

	// RepairText is an optional schema.ObjectRepairFunc; presumably invoked to
	// repair model output that fails to parse — confirm in package schema.
	RepairText schema.ObjectRepairFunc
}
48
// ObjectResponse represents the response from a structured object generation.
//
// It has the same fields as ObjectResult, except that Object is untyped (any).
// NOTE(review): presumably Object holds the decoded value and RawText the
// model's raw output text — confirm against the producers of this type.
type ObjectResponse struct {
	Object           any
	RawText          string
	Usage            Usage
	FinishReason     FinishReason
	Warnings         []CallWarning
	ProviderMetadata ProviderMetadata
}
58
// ObjectStreamPartType identifies what kind of payload an ObjectStreamPart
// carries.
type ObjectStreamPartType string

// Known ObjectStreamPartType values.
const (
	// ObjectStreamPartTypeObject marks a part carrying a new partial object.
	ObjectStreamPartTypeObject = ObjectStreamPartType("object")

	// ObjectStreamPartTypeTextDelta marks a part carrying a text delta, for
	// models that generate text.
	ObjectStreamPartTypeTextDelta = ObjectStreamPartType("text-delta")

	// ObjectStreamPartTypeError marks a part reporting an error.
	ObjectStreamPartTypeError = ObjectStreamPartType("error")

	// ObjectStreamPartTypeFinish marks the part emitted when streaming completes.
	ObjectStreamPartTypeFinish = ObjectStreamPartType("finish")
)
75
// ObjectStreamPart represents a single chunk in the object stream.
//
// Which fields are meaningful depends on Type. The StreamObjectResult
// accessors read them as follows:
//   - ObjectStreamPartTypeObject: Object (parts with a nil Object are skipped).
//   - ObjectStreamPartTypeTextDelta: Delta (empty deltas are skipped).
//   - ObjectStreamPartTypeError: Error.
//   - ObjectStreamPartTypeFinish: Usage, FinishReason, Warnings, ProviderMetadata.
type ObjectStreamPart struct {
	Type             ObjectStreamPartType
	Object           any
	Delta            string
	Error            error
	Usage            Usage
	FinishReason     FinishReason
	Warnings         []CallWarning
	ProviderMetadata ProviderMetadata
}

// ObjectStreamResponse is an iterator over ObjectStreamPart.
// It is a type alias, so any iter.Seq[ObjectStreamPart] can be used directly.
type ObjectStreamResponse = iter.Seq[ObjectStreamPart]
90
// ObjectResult is a typed result wrapper returned by GenerateObject[T].
//
// StreamObjectResult.Object returns it as well. In that path, Object is the
// value decoded from the stream's object parts, and RawText is the JSON
// encoding of the object part that produced it. Usage, FinishReason, Warnings
// and ProviderMetadata come from the stream's finish part; Warnings and
// ProviderMetadata stay nil when the finish part reported none.
type ObjectResult[T any] struct {
	Object           T
	RawText          string
	Usage            Usage
	FinishReason     FinishReason
	Warnings         []CallWarning
	ProviderMetadata ProviderMetadata
}
100
// StreamObjectResult provides typed access to a streaming object generation result.
//
// Every accessor (PartialObjectStream, TextStream, FullStream, Object) ranges
// over the same underlying stream. NOTE(review): whether that stream can be
// iterated more than once depends on the iterator passed to
// NewStreamObjectResult — confirm before consuming it from several accessors.
type StreamObjectResult[T any] struct {
	stream ObjectStreamResponse
	// ctx is stored at construction, but none of the methods visible here read
	// it. NOTE(review): storing a Context in a struct is discouraged in Go;
	// confirm whether it is used elsewhere or can be dropped.
	ctx context.Context
}
106
107// NewStreamObjectResult creates a typed stream result from an untyped stream.
108func NewStreamObjectResult[T any](ctx context.Context, stream ObjectStreamResponse) *StreamObjectResult[T] {
109 return &StreamObjectResult[T]{
110 stream: stream,
111 ctx: ctx,
112 }
113}
114
115// PartialObjectStream returns an iterator that yields progressively more complete objects.
116// Only emits when the object actually changes (deduplication).
117func (s *StreamObjectResult[T]) PartialObjectStream() iter.Seq[T] {
118 return func(yield func(T) bool) {
119 var lastObject T
120 var hasEmitted bool
121
122 for part := range s.stream {
123 if part.Type == ObjectStreamPartTypeObject && part.Object != nil {
124 var current T
125 if err := unmarshalObject(part.Object, ¤t); err != nil {
126 continue
127 }
128
129 if !hasEmitted || !reflect.DeepEqual(current, lastObject) {
130 if !yield(current) {
131 return
132 }
133 lastObject = current
134 hasEmitted = true
135 }
136 }
137 }
138 }
139}
140
141// TextStream returns an iterator that yields text deltas.
142// Useful if the model generates explanatory text alongside the object.
143func (s *StreamObjectResult[T]) TextStream() iter.Seq[string] {
144 return func(yield func(string) bool) {
145 for part := range s.stream {
146 if part.Type == ObjectStreamPartTypeTextDelta && part.Delta != "" {
147 if !yield(part.Delta) {
148 return
149 }
150 }
151 }
152 }
153}
154
// FullStream returns an iterator that yields all stream parts including errors and metadata.
//
// It returns the underlying stream unchanged, with no filtering or decoding.
// Iterating it therefore reads from the same iterator that the other
// accessors range over.
func (s *StreamObjectResult[T]) FullStream() iter.Seq[ObjectStreamPart] {
	return s.stream
}
159
160// Object waits for the stream to complete and returns the final object.
161// Returns an error if streaming fails or no valid object was generated.
162func (s *StreamObjectResult[T]) Object() (*ObjectResult[T], error) {
163 var finalObject T
164 var usage Usage
165 var finishReason FinishReason
166 var warnings []CallWarning
167 var providerMetadata ProviderMetadata
168 var rawText string
169 var lastError error
170 hasObject := false
171
172 for part := range s.stream {
173 switch part.Type {
174 case ObjectStreamPartTypeObject:
175 if part.Object != nil {
176 if err := unmarshalObject(part.Object, &finalObject); err == nil {
177 hasObject = true
178 if jsonBytes, err := json.Marshal(part.Object); err == nil {
179 rawText = string(jsonBytes)
180 }
181 }
182 }
183
184 case ObjectStreamPartTypeError:
185 lastError = part.Error
186
187 case ObjectStreamPartTypeFinish:
188 usage = part.Usage
189 finishReason = part.FinishReason
190 if len(part.Warnings) > 0 {
191 warnings = part.Warnings
192 }
193 if len(part.ProviderMetadata) > 0 {
194 providerMetadata = part.ProviderMetadata
195 }
196 }
197 }
198
199 if lastError != nil {
200 return nil, lastError
201 }
202
203 if !hasObject {
204 return nil, &NoObjectGeneratedError{
205 RawText: rawText,
206 ParseError: fmt.Errorf("no valid object generated in stream"),
207 Usage: usage,
208 FinishReason: finishReason,
209 }
210 }
211
212 return &ObjectResult[T]{
213 Object: finalObject,
214 RawText: rawText,
215 Usage: usage,
216 FinishReason: finishReason,
217 Warnings: warnings,
218 ProviderMetadata: providerMetadata,
219 }, nil
220}
221
// unmarshalObject converts an arbitrary value (for example a map[string]any)
// into target by round-tripping it through JSON.
//
// target should be a non-nil pointer; otherwise json.Unmarshal reports an
// error, which is returned wrapped. Encoding and decoding failures are
// wrapped with distinct prefixes so callers can tell them apart.
func unmarshalObject(obj any, target any) error {
	encoded, marshalErr := json.Marshal(obj)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal object: %w", marshalErr)
	}

	decodeErr := json.Unmarshal(encoded, target)
	if decodeErr == nil {
		return nil
	}
	return fmt.Errorf("failed to unmarshal into target type: %w", decodeErr)
}