1// Package object provides utilities for generating structured objects with automatic schema generation.
2// It simplifies working with typed structured outputs by handling schema reflection and unmarshaling.
3package object
4
5import (
6 "context"
7 "encoding/json"
8 "fmt"
9 "reflect"
10
11 "charm.land/fantasy"
12 "charm.land/fantasy/schema"
13)
14
15// Generate generates a structured object that matches the given type T.
16// The schema is automatically generated from T using reflection.
17//
18// Example:
19//
20// type Recipe struct {
21// Name string `json:"name"`
22// Ingredients []string `json:"ingredients"`
23// }
24//
25// result, err := object.Generate[Recipe](ctx, model, fantasy.ObjectCall{
26// Prompt: fantasy.Prompt{fantasy.NewUserMessage("Generate a lasagna recipe")},
27// })
28func Generate[T any](
29 ctx context.Context,
30 model fantasy.LanguageModel,
31 opts fantasy.ObjectCall,
32) (*fantasy.ObjectResult[T], error) {
33 var zero T
34 s := schema.Generate(reflect.TypeOf(zero))
35 opts.Schema = s
36
37 resp, err := model.GenerateObject(ctx, opts)
38 if err != nil {
39 return nil, err
40 }
41
42 var result T
43 if err := unmarshal(resp.Object, &result); err != nil {
44 return nil, fmt.Errorf("failed to unmarshal to %T: %w", result, err)
45 }
46
47 return &fantasy.ObjectResult[T]{
48 Object: result,
49 RawText: resp.RawText,
50 Usage: resp.Usage,
51 FinishReason: resp.FinishReason,
52 Warnings: resp.Warnings,
53 ProviderMetadata: resp.ProviderMetadata,
54 }, nil
55}
56
57// Stream streams a structured object that matches the given type T.
58// Returns a StreamObjectResult[T] with progressive updates and deduplication.
59//
60// Example:
61//
62// stream, err := object.Stream[Recipe](ctx, model, fantasy.ObjectCall{
63// Prompt: fantasy.Prompt{fantasy.NewUserMessage("Generate a lasagna recipe")},
64// })
65//
66// for partial := range stream.PartialObjectStream() {
67// fmt.Printf("Progress: %s\n", partial.Name)
68// }
69//
70// result, err := stream.Object() // Wait for final result
71func Stream[T any](
72 ctx context.Context,
73 model fantasy.LanguageModel,
74 opts fantasy.ObjectCall,
75) (*fantasy.StreamObjectResult[T], error) {
76 var zero T
77 s := schema.Generate(reflect.TypeOf(zero))
78 opts.Schema = s
79
80 stream, err := model.StreamObject(ctx, opts)
81 if err != nil {
82 return nil, err
83 }
84
85 return fantasy.NewStreamObjectResult[T](ctx, stream), nil
86}
87
88// GenerateWithTool is a helper for providers without native JSON mode.
89// It converts the schema to a tool definition, forces the model to call it,
90// and extracts the tool's input as the structured output.
91func GenerateWithTool(
92 ctx context.Context,
93 model fantasy.LanguageModel,
94 call fantasy.ObjectCall,
95) (*fantasy.ObjectResponse, error) {
96 toolName := call.SchemaName
97 if toolName == "" {
98 toolName = "generate_object"
99 }
100
101 toolDescription := call.SchemaDescription
102 if toolDescription == "" {
103 toolDescription = "Generate a structured object matching the schema"
104 }
105
106 tool := fantasy.FunctionTool{
107 Name: toolName,
108 Description: toolDescription,
109 InputSchema: schema.ToMap(call.Schema),
110 }
111
112 toolChoice := fantasy.SpecificToolChoice(tool.Name)
113 resp, err := model.Generate(ctx, fantasy.Call{
114 Prompt: call.Prompt,
115 Tools: []fantasy.Tool{tool},
116 ToolChoice: &toolChoice,
117 MaxOutputTokens: call.MaxOutputTokens,
118 Temperature: call.Temperature,
119 TopP: call.TopP,
120 TopK: call.TopK,
121 PresencePenalty: call.PresencePenalty,
122 FrequencyPenalty: call.FrequencyPenalty,
123 UserAgent: call.UserAgent,
124 ProviderOptions: call.ProviderOptions,
125 })
126 if err != nil {
127 return nil, fmt.Errorf("tool-based generation failed: %w", err)
128 }
129
130 toolCalls := resp.Content.ToolCalls()
131 if len(toolCalls) == 0 {
132 return nil, &fantasy.NoObjectGeneratedError{
133 RawText: resp.Content.Text(),
134 ParseError: fmt.Errorf("no tool call generated"),
135 Usage: resp.Usage,
136 FinishReason: resp.FinishReason,
137 }
138 }
139
140 toolCall := toolCalls[0]
141
142 var obj any
143 if call.RepairText != nil {
144 obj, err = schema.ParseAndValidateWithRepair(ctx, toolCall.Input, call.Schema, call.RepairText)
145 } else {
146 obj, err = schema.ParseAndValidate(toolCall.Input, call.Schema)
147 }
148
149 if err != nil {
150 if nogErr, ok := err.(*fantasy.NoObjectGeneratedError); ok {
151 nogErr.Usage = resp.Usage
152 nogErr.FinishReason = resp.FinishReason
153 }
154 return nil, err
155 }
156
157 return &fantasy.ObjectResponse{
158 Object: obj,
159 RawText: toolCall.Input,
160 Usage: resp.Usage,
161 FinishReason: resp.FinishReason,
162 Warnings: resp.Warnings,
163 ProviderMetadata: resp.ProviderMetadata,
164 }, nil
165}
166
167// GenerateWithText is a helper for providers without tool or JSON mode support.
168// It adds the schema to the system prompt and parses the text response as JSON.
169// This is a fallback for older models or simple providers.
170func GenerateWithText(
171 ctx context.Context,
172 model fantasy.LanguageModel,
173 call fantasy.ObjectCall,
174) (*fantasy.ObjectResponse, error) {
175 jsonSchemaBytes, err := json.Marshal(call.Schema)
176 if err != nil {
177 return nil, fmt.Errorf("failed to marshal schema: %w", err)
178 }
179
180 schemaInstruction := fmt.Sprintf(
181 "You must respond with valid JSON that matches this schema: %s\n"+
182 "Respond ONLY with the JSON object, no additional text or explanation.",
183 string(jsonSchemaBytes),
184 )
185
186 enhancedPrompt := make(fantasy.Prompt, 0, len(call.Prompt)+1)
187
188 hasSystem := false
189 for _, msg := range call.Prompt {
190 if msg.Role == fantasy.MessageRoleSystem {
191 hasSystem = true
192 existingText := ""
193 if len(msg.Content) > 0 {
194 if textPart, ok := msg.Content[0].(fantasy.TextPart); ok {
195 existingText = textPart.Text
196 }
197 }
198 enhancedPrompt = append(enhancedPrompt, fantasy.NewSystemMessage(existingText+"\n\n"+schemaInstruction))
199 } else {
200 enhancedPrompt = append(enhancedPrompt, msg)
201 }
202 }
203
204 if !hasSystem {
205 enhancedPrompt = append(fantasy.Prompt{fantasy.NewSystemMessage(schemaInstruction)}, call.Prompt...)
206 }
207
208 resp, err := model.Generate(ctx, fantasy.Call{
209 Prompt: enhancedPrompt,
210 MaxOutputTokens: call.MaxOutputTokens,
211 Temperature: call.Temperature,
212 TopP: call.TopP,
213 TopK: call.TopK,
214 PresencePenalty: call.PresencePenalty,
215 FrequencyPenalty: call.FrequencyPenalty,
216 UserAgent: call.UserAgent,
217 ProviderOptions: call.ProviderOptions,
218 })
219 if err != nil {
220 return nil, fmt.Errorf("text-based generation failed: %w", err)
221 }
222
223 textContent := resp.Content.Text()
224 if textContent == "" {
225 return nil, &fantasy.NoObjectGeneratedError{
226 RawText: "",
227 ParseError: fmt.Errorf("no text content in response"),
228 Usage: resp.Usage,
229 FinishReason: resp.FinishReason,
230 }
231 }
232
233 var obj any
234 if call.RepairText != nil {
235 obj, err = schema.ParseAndValidateWithRepair(ctx, textContent, call.Schema, call.RepairText)
236 } else {
237 obj, err = schema.ParseAndValidate(textContent, call.Schema)
238 }
239
240 if err != nil {
241 if nogErr, ok := err.(*schema.ParseError); ok {
242 return nil, &fantasy.NoObjectGeneratedError{
243 RawText: nogErr.RawText,
244 ParseError: nogErr.ParseError,
245 ValidationError: nogErr.ValidationError,
246 Usage: resp.Usage,
247 FinishReason: resp.FinishReason,
248 }
249 }
250 return nil, err
251 }
252
253 return &fantasy.ObjectResponse{
254 Object: obj,
255 RawText: textContent,
256 Usage: resp.Usage,
257 FinishReason: resp.FinishReason,
258 Warnings: resp.Warnings,
259 ProviderMetadata: resp.ProviderMetadata,
260 }, nil
261}
262
263// StreamWithTool is a helper for providers without native JSON streaming.
264// It uses streaming tool calls to extract and parse the structured output progressively.
265func StreamWithTool(
266 ctx context.Context,
267 model fantasy.LanguageModel,
268 call fantasy.ObjectCall,
269) (fantasy.ObjectStreamResponse, error) {
270 // Create a tool from the schema
271 toolName := call.SchemaName
272 if toolName == "" {
273 toolName = "generate_object"
274 }
275
276 toolDescription := call.SchemaDescription
277 if toolDescription == "" {
278 toolDescription = "Generate a structured object matching the schema"
279 }
280
281 tool := fantasy.FunctionTool{
282 Name: toolName,
283 Description: toolDescription,
284 InputSchema: schema.ToMap(call.Schema),
285 }
286
287 // Make a streaming Generate call with forced tool choice
288 toolChoice := fantasy.SpecificToolChoice(tool.Name)
289 stream, err := model.Stream(ctx, fantasy.Call{
290 Prompt: call.Prompt,
291 Tools: []fantasy.Tool{tool},
292 ToolChoice: &toolChoice,
293 MaxOutputTokens: call.MaxOutputTokens,
294 Temperature: call.Temperature,
295 TopP: call.TopP,
296 TopK: call.TopK,
297 PresencePenalty: call.PresencePenalty,
298 FrequencyPenalty: call.FrequencyPenalty,
299 UserAgent: call.UserAgent,
300 ProviderOptions: call.ProviderOptions,
301 })
302 if err != nil {
303 return nil, fmt.Errorf("tool-based streaming failed: %w", err)
304 }
305
306 // Convert the text stream to object stream parts
307 return func(yield func(fantasy.ObjectStreamPart) bool) {
308 var accumulated string
309 var lastParsedObject any
310 var usage fantasy.Usage
311 var finishReason fantasy.FinishReason
312 var warnings []fantasy.CallWarning
313 var providerMetadata fantasy.ProviderMetadata
314 var streamErr error
315
316 for part := range stream {
317 switch part.Type {
318 case fantasy.StreamPartTypeTextDelta:
319 accumulated += part.Delta
320
321 obj, state, parseErr := schema.ParsePartialJSON(accumulated)
322
323 if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
324 if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
325 if !reflect.DeepEqual(obj, lastParsedObject) {
326 if !yield(fantasy.ObjectStreamPart{
327 Type: fantasy.ObjectStreamPartTypeObject,
328 Object: obj,
329 }) {
330 return
331 }
332 lastParsedObject = obj
333 }
334 }
335 }
336
337 if state == schema.ParseStateFailed && call.RepairText != nil {
338 repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
339 if repairErr == nil {
340 obj2, state2, _ := schema.ParsePartialJSON(repairedText)
341 if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
342 schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
343 if !reflect.DeepEqual(obj2, lastParsedObject) {
344 if !yield(fantasy.ObjectStreamPart{
345 Type: fantasy.ObjectStreamPartTypeObject,
346 Object: obj2,
347 }) {
348 return
349 }
350 lastParsedObject = obj2
351 }
352 }
353 }
354 }
355
356 case fantasy.StreamPartTypeToolInputDelta:
357 accumulated += part.Delta
358
359 obj, state, parseErr := schema.ParsePartialJSON(accumulated)
360 if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
361 if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
362 if !reflect.DeepEqual(obj, lastParsedObject) {
363 if !yield(fantasy.ObjectStreamPart{
364 Type: fantasy.ObjectStreamPartTypeObject,
365 Object: obj,
366 }) {
367 return
368 }
369 lastParsedObject = obj
370 }
371 }
372 }
373
374 if state == schema.ParseStateFailed && call.RepairText != nil {
375 repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
376 if repairErr == nil {
377 obj2, state2, _ := schema.ParsePartialJSON(repairedText)
378 if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
379 schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
380 if !reflect.DeepEqual(obj2, lastParsedObject) {
381 if !yield(fantasy.ObjectStreamPart{
382 Type: fantasy.ObjectStreamPartTypeObject,
383 Object: obj2,
384 }) {
385 return
386 }
387 lastParsedObject = obj2
388 }
389 }
390 }
391 }
392
393 case fantasy.StreamPartTypeToolCall:
394 toolInput := part.ToolCallInput
395
396 var obj any
397 var err error
398 if call.RepairText != nil {
399 obj, err = schema.ParseAndValidateWithRepair(ctx, toolInput, call.Schema, call.RepairText)
400 } else {
401 obj, err = schema.ParseAndValidate(toolInput, call.Schema)
402 }
403
404 if err == nil {
405 if !reflect.DeepEqual(obj, lastParsedObject) {
406 if !yield(fantasy.ObjectStreamPart{
407 Type: fantasy.ObjectStreamPartTypeObject,
408 Object: obj,
409 }) {
410 return
411 }
412 lastParsedObject = obj
413 }
414 }
415
416 case fantasy.StreamPartTypeError:
417 streamErr = part.Error
418 if !yield(fantasy.ObjectStreamPart{
419 Type: fantasy.ObjectStreamPartTypeError,
420 Error: part.Error,
421 }) {
422 return
423 }
424
425 case fantasy.StreamPartTypeFinish:
426 usage = part.Usage
427 finishReason = part.FinishReason
428
429 case fantasy.StreamPartTypeWarnings:
430 warnings = part.Warnings
431 }
432
433 if len(part.ProviderMetadata) > 0 {
434 providerMetadata = part.ProviderMetadata
435 }
436 }
437
438 if streamErr == nil && lastParsedObject != nil {
439 yield(fantasy.ObjectStreamPart{
440 Type: fantasy.ObjectStreamPartTypeFinish,
441 Usage: usage,
442 FinishReason: finishReason,
443 Warnings: warnings,
444 ProviderMetadata: providerMetadata,
445 })
446 } else if streamErr == nil && lastParsedObject == nil {
447 yield(fantasy.ObjectStreamPart{
448 Type: fantasy.ObjectStreamPartTypeError,
449 Error: &fantasy.NoObjectGeneratedError{
450 RawText: accumulated,
451 ParseError: fmt.Errorf("no valid object generated in stream"),
452 Usage: usage,
453 FinishReason: finishReason,
454 },
455 })
456 }
457 }, nil
458}
459
460// StreamWithText is a helper for providers without tool or JSON streaming support.
461// It adds the schema to the system prompt and parses the streamed text as JSON progressively.
462func StreamWithText(
463 ctx context.Context,
464 model fantasy.LanguageModel,
465 call fantasy.ObjectCall,
466) (fantasy.ObjectStreamResponse, error) {
467 jsonSchemaMap := schema.ToMap(call.Schema)
468 jsonSchemaBytes, err := json.Marshal(jsonSchemaMap)
469 if err != nil {
470 return nil, fmt.Errorf("failed to marshal schema: %w", err)
471 }
472
473 schemaInstruction := fmt.Sprintf(
474 "You must respond with valid JSON that matches this schema: %s\n"+
475 "Respond ONLY with the JSON object, no additional text or explanation.",
476 string(jsonSchemaBytes),
477 )
478
479 enhancedPrompt := make(fantasy.Prompt, 0, len(call.Prompt)+1)
480
481 hasSystem := false
482 for _, msg := range call.Prompt {
483 if msg.Role == fantasy.MessageRoleSystem {
484 hasSystem = true
485 existingText := ""
486 if len(msg.Content) > 0 {
487 if textPart, ok := msg.Content[0].(fantasy.TextPart); ok {
488 existingText = textPart.Text
489 }
490 }
491 enhancedPrompt = append(enhancedPrompt, fantasy.NewSystemMessage(existingText+"\n\n"+schemaInstruction))
492 } else {
493 enhancedPrompt = append(enhancedPrompt, msg)
494 }
495 }
496
497 if !hasSystem {
498 enhancedPrompt = append(fantasy.Prompt{fantasy.NewSystemMessage(schemaInstruction)}, call.Prompt...)
499 }
500
501 stream, err := model.Stream(ctx, fantasy.Call{
502 Prompt: enhancedPrompt,
503 MaxOutputTokens: call.MaxOutputTokens,
504 Temperature: call.Temperature,
505 TopP: call.TopP,
506 TopK: call.TopK,
507 PresencePenalty: call.PresencePenalty,
508 FrequencyPenalty: call.FrequencyPenalty,
509 UserAgent: call.UserAgent,
510 ProviderOptions: call.ProviderOptions,
511 })
512 if err != nil {
513 return nil, fmt.Errorf("text-based streaming failed: %w", err)
514 }
515
516 return func(yield func(fantasy.ObjectStreamPart) bool) {
517 var accumulated string
518 var lastParsedObject any
519 var usage fantasy.Usage
520 var finishReason fantasy.FinishReason
521 var warnings []fantasy.CallWarning
522 var providerMetadata fantasy.ProviderMetadata
523 var streamErr error
524
525 for part := range stream {
526 switch part.Type {
527 case fantasy.StreamPartTypeTextDelta:
528 accumulated += part.Delta
529
530 obj, state, parseErr := schema.ParsePartialJSON(accumulated)
531
532 if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
533 if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
534 if !reflect.DeepEqual(obj, lastParsedObject) {
535 if !yield(fantasy.ObjectStreamPart{
536 Type: fantasy.ObjectStreamPartTypeObject,
537 Object: obj,
538 }) {
539 return
540 }
541 lastParsedObject = obj
542 }
543 }
544 }
545
546 if state == schema.ParseStateFailed && call.RepairText != nil {
547 repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
548 if repairErr == nil {
549 obj2, state2, _ := schema.ParsePartialJSON(repairedText)
550 if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
551 schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
552 if !reflect.DeepEqual(obj2, lastParsedObject) {
553 if !yield(fantasy.ObjectStreamPart{
554 Type: fantasy.ObjectStreamPartTypeObject,
555 Object: obj2,
556 }) {
557 return
558 }
559 lastParsedObject = obj2
560 }
561 }
562 }
563 }
564
565 case fantasy.StreamPartTypeError:
566 streamErr = part.Error
567 if !yield(fantasy.ObjectStreamPart{
568 Type: fantasy.ObjectStreamPartTypeError,
569 Error: part.Error,
570 }) {
571 return
572 }
573
574 case fantasy.StreamPartTypeFinish:
575 usage = part.Usage
576 finishReason = part.FinishReason
577
578 case fantasy.StreamPartTypeWarnings:
579 warnings = part.Warnings
580 }
581
582 if len(part.ProviderMetadata) > 0 {
583 providerMetadata = part.ProviderMetadata
584 }
585 }
586
587 if streamErr == nil && lastParsedObject != nil {
588 yield(fantasy.ObjectStreamPart{
589 Type: fantasy.ObjectStreamPartTypeFinish,
590 Usage: usage,
591 FinishReason: finishReason,
592 Warnings: warnings,
593 ProviderMetadata: providerMetadata,
594 })
595 } else if streamErr == nil && lastParsedObject == nil {
596 yield(fantasy.ObjectStreamPart{
597 Type: fantasy.ObjectStreamPartTypeError,
598 Error: &fantasy.NoObjectGeneratedError{
599 RawText: accumulated,
600 ParseError: fmt.Errorf("no valid object generated in stream"),
601 Usage: usage,
602 FinishReason: finishReason,
603 },
604 })
605 }
606 }, nil
607}
608
// unmarshal copies obj into target by encoding obj as JSON and decoding that
// JSON into target, which must be a pointer json.Unmarshal can write to.
// Encoding and decoding failures are wrapped with distinct messages.
func unmarshal(obj any, target any) error {
	encoded, marshalErr := json.Marshal(obj)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal object: %w", marshalErr)
	}

	decodeErr := json.Unmarshal(encoded, target)
	if decodeErr == nil {
		return nil
	}
	return fmt.Errorf("failed to unmarshal into target type: %w", decodeErr)
}