1// Package object provides utilities for generating structured objects with automatic schema generation.
2// It simplifies working with typed structured outputs by handling schema reflection and unmarshaling.
3package object
4
5import (
6 "context"
7 "encoding/json"
8 "fmt"
9 "reflect"
10
11 "charm.land/fantasy"
12 "charm.land/fantasy/schema"
13)
14
15// Generate generates a structured object that matches the given type T.
16// The schema is automatically generated from T using reflection.
17//
18// Example:
19//
20// type Recipe struct {
21// Name string `json:"name"`
22// Ingredients []string `json:"ingredients"`
23// }
24//
25// result, err := object.Generate[Recipe](ctx, model, fantasy.ObjectCall{
26// Prompt: fantasy.Prompt{fantasy.NewUserMessage("Generate a lasagna recipe")},
27// })
28func Generate[T any](
29 ctx context.Context,
30 model fantasy.LanguageModel,
31 opts fantasy.ObjectCall,
32) (*fantasy.ObjectResult[T], error) {
33 var zero T
34 s := schema.Generate(reflect.TypeOf(zero))
35 opts.Schema = s
36
37 resp, err := model.GenerateObject(ctx, opts)
38 if err != nil {
39 return nil, err
40 }
41
42 var result T
43 if err := unmarshal(resp.Object, &result); err != nil {
44 return nil, fmt.Errorf("failed to unmarshal to %T: %w", result, err)
45 }
46
47 return &fantasy.ObjectResult[T]{
48 Object: result,
49 RawText: resp.RawText,
50 Usage: resp.Usage,
51 FinishReason: resp.FinishReason,
52 Warnings: resp.Warnings,
53 ProviderMetadata: resp.ProviderMetadata,
54 }, nil
55}
56
57// Stream streams a structured object that matches the given type T.
58// Returns a StreamObjectResult[T] with progressive updates and deduplication.
59//
60// Example:
61//
62// stream, err := object.Stream[Recipe](ctx, model, fantasy.ObjectCall{
63// Prompt: fantasy.Prompt{fantasy.NewUserMessage("Generate a lasagna recipe")},
64// })
65//
66// for partial := range stream.PartialObjectStream() {
67// fmt.Printf("Progress: %s\n", partial.Name)
68// }
69//
70// result, err := stream.Object() // Wait for final result
71func Stream[T any](
72 ctx context.Context,
73 model fantasy.LanguageModel,
74 opts fantasy.ObjectCall,
75) (*fantasy.StreamObjectResult[T], error) {
76 var zero T
77 s := schema.Generate(reflect.TypeOf(zero))
78 opts.Schema = s
79
80 stream, err := model.StreamObject(ctx, opts)
81 if err != nil {
82 return nil, err
83 }
84
85 return fantasy.NewStreamObjectResult[T](ctx, stream), nil
86}
87
88// GenerateWithTool is a helper for providers without native JSON mode.
89// It converts the schema to a tool definition, forces the model to call it,
90// and extracts the tool's input as the structured output.
91func GenerateWithTool(
92 ctx context.Context,
93 model fantasy.LanguageModel,
94 call fantasy.ObjectCall,
95) (*fantasy.ObjectResponse, error) {
96 toolName := call.SchemaName
97 if toolName == "" {
98 toolName = "generate_object"
99 }
100
101 toolDescription := call.SchemaDescription
102 if toolDescription == "" {
103 toolDescription = "Generate a structured object matching the schema"
104 }
105
106 tool := fantasy.FunctionTool{
107 Name: toolName,
108 Description: toolDescription,
109 InputSchema: schema.ToMap(call.Schema),
110 }
111
112 toolChoice := fantasy.SpecificToolChoice(tool.Name)
113 resp, err := model.Generate(ctx, fantasy.Call{
114 Prompt: call.Prompt,
115 Tools: []fantasy.Tool{tool},
116 ToolChoice: &toolChoice,
117 MaxOutputTokens: call.MaxOutputTokens,
118 Temperature: call.Temperature,
119 TopP: call.TopP,
120 TopK: call.TopK,
121 PresencePenalty: call.PresencePenalty,
122 FrequencyPenalty: call.FrequencyPenalty,
123 ProviderOptions: call.ProviderOptions,
124 })
125 if err != nil {
126 return nil, fmt.Errorf("tool-based generation failed: %w", err)
127 }
128
129 toolCalls := resp.Content.ToolCalls()
130 if len(toolCalls) == 0 {
131 return nil, &fantasy.NoObjectGeneratedError{
132 RawText: resp.Content.Text(),
133 ParseError: fmt.Errorf("no tool call generated"),
134 Usage: resp.Usage,
135 FinishReason: resp.FinishReason,
136 }
137 }
138
139 toolCall := toolCalls[0]
140
141 var obj any
142 if call.RepairText != nil {
143 obj, err = schema.ParseAndValidateWithRepair(ctx, toolCall.Input, call.Schema, call.RepairText)
144 } else {
145 obj, err = schema.ParseAndValidate(toolCall.Input, call.Schema)
146 }
147
148 if err != nil {
149 if nogErr, ok := err.(*fantasy.NoObjectGeneratedError); ok {
150 nogErr.Usage = resp.Usage
151 nogErr.FinishReason = resp.FinishReason
152 }
153 return nil, err
154 }
155
156 return &fantasy.ObjectResponse{
157 Object: obj,
158 RawText: toolCall.Input,
159 Usage: resp.Usage,
160 FinishReason: resp.FinishReason,
161 Warnings: resp.Warnings,
162 ProviderMetadata: resp.ProviderMetadata,
163 }, nil
164}
165
166// GenerateWithText is a helper for providers without tool or JSON mode support.
167// It adds the schema to the system prompt and parses the text response as JSON.
168// This is a fallback for older models or simple providers.
169func GenerateWithText(
170 ctx context.Context,
171 model fantasy.LanguageModel,
172 call fantasy.ObjectCall,
173) (*fantasy.ObjectResponse, error) {
174 jsonSchemaBytes, err := json.Marshal(call.Schema)
175 if err != nil {
176 return nil, fmt.Errorf("failed to marshal schema: %w", err)
177 }
178
179 schemaInstruction := fmt.Sprintf(
180 "You must respond with valid JSON that matches this schema: %s\n"+
181 "Respond ONLY with the JSON object, no additional text or explanation.",
182 string(jsonSchemaBytes),
183 )
184
185 enhancedPrompt := make(fantasy.Prompt, 0, len(call.Prompt)+1)
186
187 hasSystem := false
188 for _, msg := range call.Prompt {
189 if msg.Role == fantasy.MessageRoleSystem {
190 hasSystem = true
191 existingText := ""
192 if len(msg.Content) > 0 {
193 if textPart, ok := msg.Content[0].(fantasy.TextPart); ok {
194 existingText = textPart.Text
195 }
196 }
197 enhancedPrompt = append(enhancedPrompt, fantasy.NewSystemMessage(existingText+"\n\n"+schemaInstruction))
198 } else {
199 enhancedPrompt = append(enhancedPrompt, msg)
200 }
201 }
202
203 if !hasSystem {
204 enhancedPrompt = append(fantasy.Prompt{fantasy.NewSystemMessage(schemaInstruction)}, call.Prompt...)
205 }
206
207 resp, err := model.Generate(ctx, fantasy.Call{
208 Prompt: enhancedPrompt,
209 MaxOutputTokens: call.MaxOutputTokens,
210 Temperature: call.Temperature,
211 TopP: call.TopP,
212 TopK: call.TopK,
213 PresencePenalty: call.PresencePenalty,
214 FrequencyPenalty: call.FrequencyPenalty,
215 ProviderOptions: call.ProviderOptions,
216 })
217 if err != nil {
218 return nil, fmt.Errorf("text-based generation failed: %w", err)
219 }
220
221 textContent := resp.Content.Text()
222 if textContent == "" {
223 return nil, &fantasy.NoObjectGeneratedError{
224 RawText: "",
225 ParseError: fmt.Errorf("no text content in response"),
226 Usage: resp.Usage,
227 FinishReason: resp.FinishReason,
228 }
229 }
230
231 var obj any
232 if call.RepairText != nil {
233 obj, err = schema.ParseAndValidateWithRepair(ctx, textContent, call.Schema, call.RepairText)
234 } else {
235 obj, err = schema.ParseAndValidate(textContent, call.Schema)
236 }
237
238 if err != nil {
239 if nogErr, ok := err.(*schema.ParseError); ok {
240 return nil, &fantasy.NoObjectGeneratedError{
241 RawText: nogErr.RawText,
242 ParseError: nogErr.ParseError,
243 ValidationError: nogErr.ValidationError,
244 Usage: resp.Usage,
245 FinishReason: resp.FinishReason,
246 }
247 }
248 return nil, err
249 }
250
251 return &fantasy.ObjectResponse{
252 Object: obj,
253 RawText: textContent,
254 Usage: resp.Usage,
255 FinishReason: resp.FinishReason,
256 Warnings: resp.Warnings,
257 ProviderMetadata: resp.ProviderMetadata,
258 }, nil
259}
260
261// StreamWithTool is a helper for providers without native JSON streaming.
262// It uses streaming tool calls to extract and parse the structured output progressively.
263func StreamWithTool(
264 ctx context.Context,
265 model fantasy.LanguageModel,
266 call fantasy.ObjectCall,
267) (fantasy.ObjectStreamResponse, error) {
268 // Create a tool from the schema
269 toolName := call.SchemaName
270 if toolName == "" {
271 toolName = "generate_object"
272 }
273
274 toolDescription := call.SchemaDescription
275 if toolDescription == "" {
276 toolDescription = "Generate a structured object matching the schema"
277 }
278
279 tool := fantasy.FunctionTool{
280 Name: toolName,
281 Description: toolDescription,
282 InputSchema: schema.ToMap(call.Schema),
283 }
284
285 // Make a streaming Generate call with forced tool choice
286 toolChoice := fantasy.SpecificToolChoice(tool.Name)
287 stream, err := model.Stream(ctx, fantasy.Call{
288 Prompt: call.Prompt,
289 Tools: []fantasy.Tool{tool},
290 ToolChoice: &toolChoice,
291 MaxOutputTokens: call.MaxOutputTokens,
292 Temperature: call.Temperature,
293 TopP: call.TopP,
294 TopK: call.TopK,
295 PresencePenalty: call.PresencePenalty,
296 FrequencyPenalty: call.FrequencyPenalty,
297 ProviderOptions: call.ProviderOptions,
298 })
299 if err != nil {
300 return nil, fmt.Errorf("tool-based streaming failed: %w", err)
301 }
302
303 // Convert the text stream to object stream parts
304 return func(yield func(fantasy.ObjectStreamPart) bool) {
305 var accumulated string
306 var lastParsedObject any
307 var usage fantasy.Usage
308 var finishReason fantasy.FinishReason
309 var warnings []fantasy.CallWarning
310 var providerMetadata fantasy.ProviderMetadata
311 var streamErr error
312
313 for part := range stream {
314 switch part.Type {
315 case fantasy.StreamPartTypeTextDelta:
316 accumulated += part.Delta
317
318 obj, state, parseErr := schema.ParsePartialJSON(accumulated)
319
320 if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
321 if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
322 if !reflect.DeepEqual(obj, lastParsedObject) {
323 if !yield(fantasy.ObjectStreamPart{
324 Type: fantasy.ObjectStreamPartTypeObject,
325 Object: obj,
326 }) {
327 return
328 }
329 lastParsedObject = obj
330 }
331 }
332 }
333
334 if state == schema.ParseStateFailed && call.RepairText != nil {
335 repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
336 if repairErr == nil {
337 obj2, state2, _ := schema.ParsePartialJSON(repairedText)
338 if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
339 schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
340 if !reflect.DeepEqual(obj2, lastParsedObject) {
341 if !yield(fantasy.ObjectStreamPart{
342 Type: fantasy.ObjectStreamPartTypeObject,
343 Object: obj2,
344 }) {
345 return
346 }
347 lastParsedObject = obj2
348 }
349 }
350 }
351 }
352
353 case fantasy.StreamPartTypeToolInputDelta:
354 accumulated += part.Delta
355
356 obj, state, parseErr := schema.ParsePartialJSON(accumulated)
357 if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
358 if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
359 if !reflect.DeepEqual(obj, lastParsedObject) {
360 if !yield(fantasy.ObjectStreamPart{
361 Type: fantasy.ObjectStreamPartTypeObject,
362 Object: obj,
363 }) {
364 return
365 }
366 lastParsedObject = obj
367 }
368 }
369 }
370
371 if state == schema.ParseStateFailed && call.RepairText != nil {
372 repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
373 if repairErr == nil {
374 obj2, state2, _ := schema.ParsePartialJSON(repairedText)
375 if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
376 schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
377 if !reflect.DeepEqual(obj2, lastParsedObject) {
378 if !yield(fantasy.ObjectStreamPart{
379 Type: fantasy.ObjectStreamPartTypeObject,
380 Object: obj2,
381 }) {
382 return
383 }
384 lastParsedObject = obj2
385 }
386 }
387 }
388 }
389
390 case fantasy.StreamPartTypeToolCall:
391 toolInput := part.ToolCallInput
392
393 var obj any
394 var err error
395 if call.RepairText != nil {
396 obj, err = schema.ParseAndValidateWithRepair(ctx, toolInput, call.Schema, call.RepairText)
397 } else {
398 obj, err = schema.ParseAndValidate(toolInput, call.Schema)
399 }
400
401 if err == nil {
402 if !reflect.DeepEqual(obj, lastParsedObject) {
403 if !yield(fantasy.ObjectStreamPart{
404 Type: fantasy.ObjectStreamPartTypeObject,
405 Object: obj,
406 }) {
407 return
408 }
409 lastParsedObject = obj
410 }
411 }
412
413 case fantasy.StreamPartTypeError:
414 streamErr = part.Error
415 if !yield(fantasy.ObjectStreamPart{
416 Type: fantasy.ObjectStreamPartTypeError,
417 Error: part.Error,
418 }) {
419 return
420 }
421
422 case fantasy.StreamPartTypeFinish:
423 usage = part.Usage
424 finishReason = part.FinishReason
425
426 case fantasy.StreamPartTypeWarnings:
427 warnings = part.Warnings
428 }
429
430 if len(part.ProviderMetadata) > 0 {
431 providerMetadata = part.ProviderMetadata
432 }
433 }
434
435 if streamErr == nil && lastParsedObject != nil {
436 yield(fantasy.ObjectStreamPart{
437 Type: fantasy.ObjectStreamPartTypeFinish,
438 Usage: usage,
439 FinishReason: finishReason,
440 Warnings: warnings,
441 ProviderMetadata: providerMetadata,
442 })
443 } else if streamErr == nil && lastParsedObject == nil {
444 yield(fantasy.ObjectStreamPart{
445 Type: fantasy.ObjectStreamPartTypeError,
446 Error: &fantasy.NoObjectGeneratedError{
447 RawText: accumulated,
448 ParseError: fmt.Errorf("no valid object generated in stream"),
449 Usage: usage,
450 FinishReason: finishReason,
451 },
452 })
453 }
454 }, nil
455}
456
457// StreamWithText is a helper for providers without tool or JSON streaming support.
458// It adds the schema to the system prompt and parses the streamed text as JSON progressively.
459func StreamWithText(
460 ctx context.Context,
461 model fantasy.LanguageModel,
462 call fantasy.ObjectCall,
463) (fantasy.ObjectStreamResponse, error) {
464 jsonSchemaMap := schema.ToMap(call.Schema)
465 jsonSchemaBytes, err := json.Marshal(jsonSchemaMap)
466 if err != nil {
467 return nil, fmt.Errorf("failed to marshal schema: %w", err)
468 }
469
470 schemaInstruction := fmt.Sprintf(
471 "You must respond with valid JSON that matches this schema: %s\n"+
472 "Respond ONLY with the JSON object, no additional text or explanation.",
473 string(jsonSchemaBytes),
474 )
475
476 enhancedPrompt := make(fantasy.Prompt, 0, len(call.Prompt)+1)
477
478 hasSystem := false
479 for _, msg := range call.Prompt {
480 if msg.Role == fantasy.MessageRoleSystem {
481 hasSystem = true
482 existingText := ""
483 if len(msg.Content) > 0 {
484 if textPart, ok := msg.Content[0].(fantasy.TextPart); ok {
485 existingText = textPart.Text
486 }
487 }
488 enhancedPrompt = append(enhancedPrompt, fantasy.NewSystemMessage(existingText+"\n\n"+schemaInstruction))
489 } else {
490 enhancedPrompt = append(enhancedPrompt, msg)
491 }
492 }
493
494 if !hasSystem {
495 enhancedPrompt = append(fantasy.Prompt{fantasy.NewSystemMessage(schemaInstruction)}, call.Prompt...)
496 }
497
498 stream, err := model.Stream(ctx, fantasy.Call{
499 Prompt: enhancedPrompt,
500 MaxOutputTokens: call.MaxOutputTokens,
501 Temperature: call.Temperature,
502 TopP: call.TopP,
503 TopK: call.TopK,
504 PresencePenalty: call.PresencePenalty,
505 FrequencyPenalty: call.FrequencyPenalty,
506 ProviderOptions: call.ProviderOptions,
507 })
508 if err != nil {
509 return nil, fmt.Errorf("text-based streaming failed: %w", err)
510 }
511
512 return func(yield func(fantasy.ObjectStreamPart) bool) {
513 var accumulated string
514 var lastParsedObject any
515 var usage fantasy.Usage
516 var finishReason fantasy.FinishReason
517 var warnings []fantasy.CallWarning
518 var providerMetadata fantasy.ProviderMetadata
519 var streamErr error
520
521 for part := range stream {
522 switch part.Type {
523 case fantasy.StreamPartTypeTextDelta:
524 accumulated += part.Delta
525
526 obj, state, parseErr := schema.ParsePartialJSON(accumulated)
527
528 if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
529 if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
530 if !reflect.DeepEqual(obj, lastParsedObject) {
531 if !yield(fantasy.ObjectStreamPart{
532 Type: fantasy.ObjectStreamPartTypeObject,
533 Object: obj,
534 }) {
535 return
536 }
537 lastParsedObject = obj
538 }
539 }
540 }
541
542 if state == schema.ParseStateFailed && call.RepairText != nil {
543 repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
544 if repairErr == nil {
545 obj2, state2, _ := schema.ParsePartialJSON(repairedText)
546 if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
547 schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
548 if !reflect.DeepEqual(obj2, lastParsedObject) {
549 if !yield(fantasy.ObjectStreamPart{
550 Type: fantasy.ObjectStreamPartTypeObject,
551 Object: obj2,
552 }) {
553 return
554 }
555 lastParsedObject = obj2
556 }
557 }
558 }
559 }
560
561 case fantasy.StreamPartTypeError:
562 streamErr = part.Error
563 if !yield(fantasy.ObjectStreamPart{
564 Type: fantasy.ObjectStreamPartTypeError,
565 Error: part.Error,
566 }) {
567 return
568 }
569
570 case fantasy.StreamPartTypeFinish:
571 usage = part.Usage
572 finishReason = part.FinishReason
573
574 case fantasy.StreamPartTypeWarnings:
575 warnings = part.Warnings
576 }
577
578 if len(part.ProviderMetadata) > 0 {
579 providerMetadata = part.ProviderMetadata
580 }
581 }
582
583 if streamErr == nil && lastParsedObject != nil {
584 yield(fantasy.ObjectStreamPart{
585 Type: fantasy.ObjectStreamPartTypeFinish,
586 Usage: usage,
587 FinishReason: finishReason,
588 Warnings: warnings,
589 ProviderMetadata: providerMetadata,
590 })
591 } else if streamErr == nil && lastParsedObject == nil {
592 yield(fantasy.ObjectStreamPart{
593 Type: fantasy.ObjectStreamPartTypeError,
594 Error: &fantasy.NoObjectGeneratedError{
595 RawText: accumulated,
596 ParseError: fmt.Errorf("no valid object generated in stream"),
597 Usage: usage,
598 FinishReason: finishReason,
599 },
600 })
601 }
602 }, nil
603}
604
// unmarshal copies obj into target by encoding obj as JSON and decoding the
// bytes into target. target must be a non-nil pointer, as json.Unmarshal
// requires. Encoding and decoding failures are wrapped with distinct
// prefixes so callers can tell which step failed.
func unmarshal(obj, target any) error {
	encoded, err := json.Marshal(obj)
	if err != nil {
		return fmt.Errorf("failed to marshal object: %w", err)
	}

	err = json.Unmarshal(encoded, target)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to unmarshal into target type: %w", err)
}