1package openai
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "reflect"
10 "strings"
11
12 "charm.land/fantasy"
13 "charm.land/fantasy/object"
14 "charm.land/fantasy/schema"
15 xjson "github.com/charmbracelet/x/json"
16 "github.com/google/uuid"
17 "github.com/openai/openai-go/v2"
18 "github.com/openai/openai-go/v2/packages/param"
19 "github.com/openai/openai-go/v2/shared"
20)
21
// languageModel is the OpenAI-backed implementation of fantasy.LanguageModel.
// Provider-specific behavior is injected through the *Func hooks so that
// compatible providers can reuse this implementation with small overrides.
type languageModel struct {
	provider   string              // provider name reported by Provider()
	modelID    string              // model identifier reported by Model()
	client     openai.Client       // underlying OpenAI API client
	objectMode fantasy.ObjectMode  // strategy used by GenerateObject/StreamObject
	// prepareCallFunc applies provider-specific request options in prepareParams.
	prepareCallFunc LanguageModelPrepareCallFunc
	// mapFinishReasonFunc converts the provider finish reason string to a fantasy.FinishReason.
	mapFinishReasonFunc LanguageModelMapFinishReasonFunc
	// extraContentFunc optionally derives extra content parts from a response choice.
	extraContentFunc LanguageModelExtraContentFunc
	// usageFunc extracts token usage and provider metadata from a completion response.
	usageFunc LanguageModelUsageFunc
	// streamUsageFunc extracts usage/metadata incrementally from stream chunks.
	streamUsageFunc LanguageModelStreamUsageFunc
	// streamExtraFunc optionally emits additional stream parts per chunk.
	streamExtraFunc LanguageModelStreamExtraFunc
	// streamProviderMetadataFunc finalizes provider metadata from the accumulated choice.
	streamProviderMetadataFunc LanguageModelStreamProviderMetadataFunc
	// toPromptFunc converts a fantasy prompt into OpenAI chat messages.
	toPromptFunc LanguageModelToPromptFunc
}
36
// LanguageModelOption is a functional option that configures a languageModel
// during construction via newLanguageModel.
type LanguageModelOption = func(*languageModel)
39
40// WithLanguageModelPrepareCallFunc sets the prepare call function for the language model.
41func WithLanguageModelPrepareCallFunc(fn LanguageModelPrepareCallFunc) LanguageModelOption {
42 return func(l *languageModel) {
43 l.prepareCallFunc = fn
44 }
45}
46
47// WithLanguageModelMapFinishReasonFunc sets the map finish reason function for the language model.
48func WithLanguageModelMapFinishReasonFunc(fn LanguageModelMapFinishReasonFunc) LanguageModelOption {
49 return func(l *languageModel) {
50 l.mapFinishReasonFunc = fn
51 }
52}
53
54// WithLanguageModelExtraContentFunc sets the extra content function for the language model.
55func WithLanguageModelExtraContentFunc(fn LanguageModelExtraContentFunc) LanguageModelOption {
56 return func(l *languageModel) {
57 l.extraContentFunc = fn
58 }
59}
60
61// WithLanguageModelStreamExtraFunc sets the stream extra function for the language model.
62func WithLanguageModelStreamExtraFunc(fn LanguageModelStreamExtraFunc) LanguageModelOption {
63 return func(l *languageModel) {
64 l.streamExtraFunc = fn
65 }
66}
67
68// WithLanguageModelUsageFunc sets the usage function for the language model.
69func WithLanguageModelUsageFunc(fn LanguageModelUsageFunc) LanguageModelOption {
70 return func(l *languageModel) {
71 l.usageFunc = fn
72 }
73}
74
75// WithLanguageModelStreamUsageFunc sets the stream usage function for the language model.
76func WithLanguageModelStreamUsageFunc(fn LanguageModelStreamUsageFunc) LanguageModelOption {
77 return func(l *languageModel) {
78 l.streamUsageFunc = fn
79 }
80}
81
82// WithLanguageModelToPromptFunc sets the to prompt function for the language model.
83func WithLanguageModelToPromptFunc(fn LanguageModelToPromptFunc) LanguageModelOption {
84 return func(l *languageModel) {
85 l.toPromptFunc = fn
86 }
87}
88
89// WithLanguageModelObjectMode sets the object generation mode.
90func WithLanguageModelObjectMode(om fantasy.ObjectMode) LanguageModelOption {
91 return func(l *languageModel) {
92 // not supported
93 if om == fantasy.ObjectModeJSON {
94 om = fantasy.ObjectModeAuto
95 }
96 l.objectMode = om
97 }
98}
99
100func newLanguageModel(modelID string, provider string, client openai.Client, opts ...LanguageModelOption) languageModel {
101 model := languageModel{
102 modelID: modelID,
103 provider: provider,
104 client: client,
105 objectMode: fantasy.ObjectModeAuto,
106 prepareCallFunc: DefaultPrepareCallFunc,
107 mapFinishReasonFunc: DefaultMapFinishReasonFunc,
108 usageFunc: DefaultUsageFunc,
109 streamUsageFunc: DefaultStreamUsageFunc,
110 streamProviderMetadataFunc: DefaultStreamProviderMetadataFunc,
111 toPromptFunc: DefaultToPrompt,
112 }
113
114 for _, o := range opts {
115 o(&model)
116 }
117 return model
118}
119
// streamToolCall accumulates a tool call that arrives incrementally over
// multiple stream chunks, keyed by the delta's tool-call index.
type streamToolCall struct {
	id        string // tool call ID from the first delta
	name      string // function name from the first delta
	arguments string // JSON argument text, concatenated across deltas
	// hasFinished is set once the accumulated arguments form valid JSON
	// and the final tool-call part has been emitted.
	hasFinished bool
}
126
// Model implements fantasy.LanguageModel. It returns the model identifier
// this instance was constructed with.
func (o languageModel) Model() string {
	return o.modelID
}
131
// Provider implements fantasy.LanguageModel. It returns the provider name
// this instance was constructed with.
func (o languageModel) Provider() string {
	return o.provider
}
136
// prepareParams translates a fantasy.Call into OpenAI chat completion
// request parameters. It returns the params, warnings for settings that
// are unsupported or were adjusted for the target model, and any error
// from the provider-specific prepare hook.
func (o languageModel) prepareParams(call fantasy.Call) (*openai.ChatCompletionNewParams, []fantasy.CallWarning, error) {
	params := &openai.ChatCompletionNewParams{}
	// Convert the prompt into OpenAI messages; conversion itself can yield warnings.
	messages, warnings := o.toPromptFunc(call.Prompt, o.provider, o.modelID)
	// top_k has no equivalent in the chat completions API.
	if call.TopK != nil {
		warnings = append(warnings, fantasy.CallWarning{
			Type:    fantasy.CallWarningTypeUnsupportedSetting,
			Setting: "top_k",
		})
	}

	// Copy optional sampling settings only when the caller set them.
	if call.MaxOutputTokens != nil {
		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
	}
	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}
	if call.FrequencyPenalty != nil {
		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
	}
	if call.PresencePenalty != nil {
		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
	}

	if isReasoningModel(o.modelID) {
		// remove unsupported settings for reasoning models
		// see https://platform.openai.com/docs/guides/reasoning#limitations
		// NOTE(review): the Setting names below mix snake_case ("temperature")
		// with CamelCase ("TopP"); left as-is since they are runtime data.
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for reasoning models",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "TopP",
				Details: "TopP is not supported for reasoning models",
			})
		}
		if call.FrequencyPenalty != nil {
			params.FrequencyPenalty = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "FrequencyPenalty",
				Details: "FrequencyPenalty is not supported for reasoning models",
			})
		}
		if call.PresencePenalty != nil {
			params.PresencePenalty = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "PresencePenalty",
				Details: "PresencePenalty is not supported for reasoning models",
			})
		}

		// reasoning models use max_completion_tokens instead of max_tokens
		if call.MaxOutputTokens != nil {
			// Don't overwrite a value the prepare hook may have set already.
			if !params.MaxCompletionTokens.Valid() {
				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
			}
			params.MaxTokens = param.Opt[int64]{}
		}
	}

	// Handle search preview models
	if isSearchPreviewModel(o.modelID) {
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for the search preview models and has been removed.",
			})
		}
	}

	// Provider-specific adjustments (e.g. service tiers, extra options).
	optionsWarnings, err := o.prepareCallFunc(o, params, call)
	if err != nil {
		return nil, nil, err
	}

	if len(optionsWarnings) > 0 {
		warnings = append(warnings, optionsWarnings...)
	}

	params.Messages = messages
	params.Model = o.modelID

	// Convert tool definitions and the tool-choice directive, if any.
	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}
	return params, warnings, nil
}
242
// Generate implements fantasy.LanguageModel. It performs a single,
// non-streaming chat completion request and converts the first choice into
// fantasy content parts (text, provider extras, tool calls, and URL
// citation sources), along with usage, finish reason, and warnings.
func (o languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, toProviderErr(err)
	}

	if len(response.Choices) == 0 {
		return nil, &fantasy.Error{Title: "no response", Message: "no response generated"}
	}
	choice := response.Choices[0]
	// Capacity: optional text part + one part per tool call + per annotation.
	content := make([]fantasy.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
	text := choice.Message.Content
	if text != "" {
		content = append(content, fantasy.TextContent{
			Text: text,
		})
	}
	// Optional provider-specific extra content (e.g. reasoning parts).
	if o.extraContentFunc != nil {
		extraContent := o.extraContentFunc(choice)
		content = append(content, extraContent...)
	}
	for _, tc := range choice.Message.ToolCalls {
		toolCallID := tc.ID
		content = append(content, fantasy.ToolCallContent{
			ProviderExecuted: false,
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Surface URL citations as source content parts.
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, fantasy.SourceContent{
				SourceType: fantasy.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	usage, providerMetadata := o.usageFunc(*response)

	// A choice that produced tool calls is always reported as a tool-call
	// finish, overriding the provider's raw finish reason.
	mappedFinishReason := o.mapFinishReasonFunc(choice.FinishReason)
	if len(choice.Message.ToolCalls) > 0 {
		mappedFinishReason = fantasy.FinishReasonToolCalls
	}
	return &fantasy.Response{
		Content:      content,
		Usage:        usage,
		FinishReason: mappedFinishReason,
		ProviderMetadata: fantasy.ProviderMetadata{
			Name: providerMetadata,
		},
		Warnings: warnings,
	}, nil
}
305
306// Stream implements fantasy.LanguageModel.
307func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
308 params, warnings, err := o.prepareParams(call)
309 if err != nil {
310 return nil, err
311 }
312
313 params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
314 IncludeUsage: openai.Bool(true),
315 }
316
317 stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
318 isActiveText := false
319 toolCalls := make(map[int64]streamToolCall)
320
321 providerMetadata := fantasy.ProviderMetadata{
322 Name: &ProviderMetadata{},
323 }
324 acc := openai.ChatCompletionAccumulator{}
325 extraContext := make(map[string]any)
326 var usage fantasy.Usage
327 var finishReason string
328 return func(yield func(fantasy.StreamPart) bool) {
329 if len(warnings) > 0 {
330 if !yield(fantasy.StreamPart{
331 Type: fantasy.StreamPartTypeWarnings,
332 Warnings: warnings,
333 }) {
334 return
335 }
336 }
337 for stream.Next() {
338 chunk := stream.Current()
339 acc.AddChunk(chunk)
340 usage, providerMetadata = o.streamUsageFunc(chunk, extraContext, providerMetadata)
341 if len(chunk.Choices) == 0 {
342 continue
343 }
344 for _, choice := range chunk.Choices {
345 if choice.FinishReason != "" {
346 finishReason = choice.FinishReason
347 }
348 switch {
349 case choice.Delta.Content != "":
350 if !isActiveText {
351 isActiveText = true
352 if !yield(fantasy.StreamPart{
353 Type: fantasy.StreamPartTypeTextStart,
354 ID: "0",
355 }) {
356 return
357 }
358 }
359 if !yield(fantasy.StreamPart{
360 Type: fantasy.StreamPartTypeTextDelta,
361 ID: "0",
362 Delta: choice.Delta.Content,
363 }) {
364 return
365 }
366 case len(choice.Delta.ToolCalls) > 0:
367 if isActiveText {
368 isActiveText = false
369 if !yield(fantasy.StreamPart{
370 Type: fantasy.StreamPartTypeTextEnd,
371 ID: "0",
372 }) {
373 return
374 }
375 }
376
377 for _, toolCallDelta := range choice.Delta.ToolCalls {
378 if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
379 if existingToolCall.hasFinished {
380 continue
381 }
382 if toolCallDelta.Function.Arguments != "" {
383 existingToolCall.arguments += toolCallDelta.Function.Arguments
384 }
385 if !yield(fantasy.StreamPart{
386 Type: fantasy.StreamPartTypeToolInputDelta,
387 ID: existingToolCall.id,
388 Delta: toolCallDelta.Function.Arguments,
389 }) {
390 return
391 }
392 toolCalls[toolCallDelta.Index] = existingToolCall
393 if xjson.IsValid(existingToolCall.arguments) {
394 if !yield(fantasy.StreamPart{
395 Type: fantasy.StreamPartTypeToolInputEnd,
396 ID: existingToolCall.id,
397 }) {
398 return
399 }
400
401 if !yield(fantasy.StreamPart{
402 Type: fantasy.StreamPartTypeToolCall,
403 ID: existingToolCall.id,
404 ToolCallName: existingToolCall.name,
405 ToolCallInput: existingToolCall.arguments,
406 }) {
407 return
408 }
409 existingToolCall.hasFinished = true
410 toolCalls[toolCallDelta.Index] = existingToolCall
411 }
412 } else {
413 var err error
414 if toolCallDelta.Type != "function" {
415 err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function' type."}
416 }
417 if toolCallDelta.ID == "" {
418 err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'id' to be a string."}
419 }
420 if toolCallDelta.Function.Name == "" {
421 err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function.name' to be a string."}
422 }
423 if err != nil {
424 yield(fantasy.StreamPart{
425 Type: fantasy.StreamPartTypeError,
426 Error: toProviderErr(stream.Err()),
427 })
428 return
429 }
430
431 if !yield(fantasy.StreamPart{
432 Type: fantasy.StreamPartTypeToolInputStart,
433 ID: toolCallDelta.ID,
434 ToolCallName: toolCallDelta.Function.Name,
435 }) {
436 return
437 }
438 toolCalls[toolCallDelta.Index] = streamToolCall{
439 id: toolCallDelta.ID,
440 name: toolCallDelta.Function.Name,
441 arguments: toolCallDelta.Function.Arguments,
442 }
443
444 exTc := toolCalls[toolCallDelta.Index]
445 if exTc.arguments != "" {
446 if !yield(fantasy.StreamPart{
447 Type: fantasy.StreamPartTypeToolInputDelta,
448 ID: exTc.id,
449 Delta: exTc.arguments,
450 }) {
451 return
452 }
453 if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
454 if !yield(fantasy.StreamPart{
455 Type: fantasy.StreamPartTypeToolInputEnd,
456 ID: toolCallDelta.ID,
457 }) {
458 return
459 }
460
461 if !yield(fantasy.StreamPart{
462 Type: fantasy.StreamPartTypeToolCall,
463 ID: exTc.id,
464 ToolCallName: exTc.name,
465 ToolCallInput: exTc.arguments,
466 }) {
467 return
468 }
469 exTc.hasFinished = true
470 toolCalls[toolCallDelta.Index] = exTc
471 }
472 }
473 continue
474 }
475 }
476 }
477
478 if o.streamExtraFunc != nil {
479 updatedContext, shouldContinue := o.streamExtraFunc(chunk, yield, extraContext)
480 if !shouldContinue {
481 return
482 }
483 extraContext = updatedContext
484 }
485 }
486
487 for _, choice := range chunk.Choices {
488 if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
489 for _, annotation := range annotations {
490 if annotation.Type == "url_citation" {
491 if !yield(fantasy.StreamPart{
492 Type: fantasy.StreamPartTypeSource,
493 ID: uuid.NewString(),
494 SourceType: fantasy.SourceTypeURL,
495 URL: annotation.URLCitation.URL,
496 Title: annotation.URLCitation.Title,
497 }) {
498 return
499 }
500 }
501 }
502 }
503 }
504 }
505 err := stream.Err()
506 if err == nil || errors.Is(err, io.EOF) {
507 if isActiveText {
508 isActiveText = false
509 if !yield(fantasy.StreamPart{
510 Type: fantasy.StreamPartTypeTextEnd,
511 ID: "0",
512 }) {
513 return
514 }
515 }
516
517 if len(acc.Choices) > 0 {
518 choice := acc.Choices[0]
519 providerMetadata = o.streamProviderMetadataFunc(choice, providerMetadata)
520
521 for _, annotation := range choice.Message.Annotations {
522 if annotation.Type == "url_citation" {
523 if !yield(fantasy.StreamPart{
524 Type: fantasy.StreamPartTypeSource,
525 ID: acc.ID,
526 SourceType: fantasy.SourceTypeURL,
527 URL: annotation.URLCitation.URL,
528 Title: annotation.URLCitation.Title,
529 }) {
530 return
531 }
532 }
533 }
534 }
535 mappedFinishReason := o.mapFinishReasonFunc(finishReason)
536 if len(acc.Choices) > 0 {
537 choice := acc.Choices[0]
538 if len(choice.Message.ToolCalls) > 0 {
539 mappedFinishReason = fantasy.FinishReasonToolCalls
540 }
541 }
542 yield(fantasy.StreamPart{
543 Type: fantasy.StreamPartTypeFinish,
544 Usage: usage,
545 FinishReason: mappedFinishReason,
546 ProviderMetadata: providerMetadata,
547 })
548 return
549 } else { //nolint: revive
550 yield(fantasy.StreamPart{
551 Type: fantasy.StreamPartTypeError,
552 Error: toProviderErr(err),
553 })
554 return
555 }
556 }, nil
557}
558
// isReasoningModel reports whether the model ID names an OpenAI reasoning
// model (o1/o3/o4 families, oss variants, or the gpt-5 family), for which
// several sampling settings are unsupported.
//
// Fix: the former `strings.Contains(modelID, "gpt-5-chat")` clause was dead
// code — any string containing "gpt-5-chat" also contains "gpt-5", so the
// preceding check already matched. Removing it preserves behavior exactly.
func isReasoningModel(modelID string) bool {
	return strings.HasPrefix(modelID, "o1") || strings.Contains(modelID, "-o1") ||
		strings.HasPrefix(modelID, "o3") || strings.Contains(modelID, "-o3") ||
		strings.HasPrefix(modelID, "o4") || strings.Contains(modelID, "-o4") ||
		strings.HasPrefix(modelID, "oss") || strings.Contains(modelID, "-oss") ||
		strings.Contains(modelID, "gpt-5")
}
566
// isSearchPreviewModel reports whether the model ID names a search-preview
// model (e.g. "gpt-4o-search-preview"), which rejects the temperature setting.
func isSearchPreviewModel(modelID string) bool {
	return strings.Contains(modelID, "search-preview")
}
570
// supportsFlexProcessing reports whether the model ID belongs to a family
// eligible for the flex service tier (o3, o4-mini, gpt-5).
func supportsFlexProcessing(modelID string) bool {
	switch {
	case strings.HasPrefix(modelID, "o3"),
		strings.Contains(modelID, "-o3"),
		strings.Contains(modelID, "o4-mini"),
		strings.Contains(modelID, "gpt-5"):
		return true
	default:
		return false
	}
}
575
// supportsPriorityProcessing reports whether the model ID belongs to a
// family eligible for the priority service tier (gpt-4, gpt-5, o3, o4-mini).
//
// Fix: the former `strings.Contains(modelID, "gpt-5-mini")` clause was dead
// code — any string containing "gpt-5-mini" also contains "gpt-5", which is
// checked first. Removing it preserves behavior exactly.
func supportsPriorityProcessing(modelID string) bool {
	return strings.Contains(modelID, "gpt-4") || strings.Contains(modelID, "gpt-5") ||
		strings.HasPrefix(modelID, "o3") || strings.Contains(modelID, "-o3") ||
		strings.Contains(modelID, "o4-mini")
}
581
582func toOpenAiTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []fantasy.CallWarning) {
583 for _, tool := range tools {
584 if tool.GetType() == fantasy.ToolTypeFunction {
585 ft, ok := tool.(fantasy.FunctionTool)
586 if !ok {
587 continue
588 }
589 openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
590 OfFunction: &openai.ChatCompletionFunctionToolParam{
591 Function: shared.FunctionDefinitionParam{
592 Name: ft.Name,
593 Description: param.NewOpt(ft.Description),
594 Parameters: openai.FunctionParameters(ft.InputSchema),
595 Strict: param.NewOpt(false),
596 },
597 Type: "function",
598 },
599 })
600 continue
601 }
602
603 warnings = append(warnings, fantasy.CallWarning{
604 Type: fantasy.CallWarningTypeUnsupportedTool,
605 Tool: tool,
606 Message: "tool is not supported",
607 })
608 }
609 if toolChoice == nil {
610 return openAiTools, openAiToolChoice, warnings
611 }
612
613 switch *toolChoice {
614 case fantasy.ToolChoiceAuto:
615 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
616 OfAuto: param.NewOpt("auto"),
617 }
618 case fantasy.ToolChoiceNone:
619 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
620 OfAuto: param.NewOpt("none"),
621 }
622 case fantasy.ToolChoiceRequired:
623 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
624 OfAuto: param.NewOpt("required"),
625 }
626 default:
627 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
628 OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
629 Type: "function",
630 Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
631 Name: string(*toolChoice),
632 },
633 },
634 }
635 }
636 return openAiTools, openAiToolChoice, warnings
637}
638
639// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
640func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
641 var annotations []openai.ChatCompletionMessageAnnotation
642
643 // Parse the raw JSON to extract annotations
644 var deltaData map[string]any
645 if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
646 return annotations
647 }
648
649 // Check if annotations exist in the delta
650 if annotationsData, ok := deltaData["annotations"].([]any); ok {
651 for _, annotationData := range annotationsData {
652 if annotationMap, ok := annotationData.(map[string]any); ok {
653 if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
654 if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
655 url, urlOk := urlCitationData["url"].(string)
656 title, titleOk := urlCitationData["title"].(string)
657 if urlOk && titleOk {
658 annotation := openai.ChatCompletionMessageAnnotation{
659 Type: "url_citation",
660 URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
661 URL: url,
662 Title: title,
663 },
664 }
665 annotations = append(annotations, annotation)
666 }
667 }
668 }
669 }
670 }
671 }
672
673 return annotations
674}
675
676// GenerateObject implements fantasy.LanguageModel.
677func (o languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
678 switch o.objectMode {
679 case fantasy.ObjectModeText:
680 return object.GenerateWithText(ctx, o, call)
681 case fantasy.ObjectModeTool:
682 return object.GenerateWithTool(ctx, o, call)
683 default:
684 return o.generateObjectWithJSONMode(ctx, call)
685 }
686}
687
688// StreamObject implements fantasy.LanguageModel.
689func (o languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
690 switch o.objectMode {
691 case fantasy.ObjectModeTool:
692 return object.StreamWithTool(ctx, o, call)
693 case fantasy.ObjectModeText:
694 return object.StreamWithText(ctx, o, call)
695 default:
696 return o.streamObjectWithJSONMode(ctx, call)
697 }
698}
699
700func (o languageModel) generateObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
701 jsonSchemaMap := schema.ToMap(call.Schema)
702
703 addAdditionalPropertiesFalse(jsonSchemaMap)
704
705 schemaName := call.SchemaName
706 if schemaName == "" {
707 schemaName = "response"
708 }
709
710 fantasyCall := fantasy.Call{
711 Prompt: call.Prompt,
712 MaxOutputTokens: call.MaxOutputTokens,
713 Temperature: call.Temperature,
714 TopP: call.TopP,
715 PresencePenalty: call.PresencePenalty,
716 FrequencyPenalty: call.FrequencyPenalty,
717 ProviderOptions: call.ProviderOptions,
718 }
719
720 params, warnings, err := o.prepareParams(fantasyCall)
721 if err != nil {
722 return nil, err
723 }
724
725 params.ResponseFormat = openai.ChatCompletionNewParamsResponseFormatUnion{
726 OfJSONSchema: &shared.ResponseFormatJSONSchemaParam{
727 JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
728 Name: schemaName,
729 Description: param.NewOpt(call.SchemaDescription),
730 Schema: jsonSchemaMap,
731 Strict: param.NewOpt(true),
732 },
733 },
734 }
735
736 response, err := o.client.Chat.Completions.New(ctx, *params)
737 if err != nil {
738 return nil, toProviderErr(err)
739 }
740
741 if len(response.Choices) == 0 {
742 usage, _ := o.usageFunc(*response)
743 return nil, &fantasy.NoObjectGeneratedError{
744 RawText: "",
745 ParseError: fmt.Errorf("no choices in response"),
746 Usage: usage,
747 FinishReason: fantasy.FinishReasonUnknown,
748 }
749 }
750
751 choice := response.Choices[0]
752 jsonText := choice.Message.Content
753
754 var obj any
755 if call.RepairText != nil {
756 obj, err = schema.ParseAndValidateWithRepair(ctx, jsonText, call.Schema, call.RepairText)
757 } else {
758 obj, err = schema.ParseAndValidate(jsonText, call.Schema)
759 }
760
761 usage, _ := o.usageFunc(*response)
762 finishReason := o.mapFinishReasonFunc(choice.FinishReason)
763
764 if err != nil {
765 if nogErr, ok := err.(*fantasy.NoObjectGeneratedError); ok {
766 nogErr.Usage = usage
767 nogErr.FinishReason = finishReason
768 }
769 return nil, err
770 }
771
772 return &fantasy.ObjectResponse{
773 Object: obj,
774 RawText: jsonText,
775 Usage: usage,
776 FinishReason: finishReason,
777 Warnings: warnings,
778 }, nil
779}
780
781func (o languageModel) streamObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
782 jsonSchemaMap := schema.ToMap(call.Schema)
783
784 addAdditionalPropertiesFalse(jsonSchemaMap)
785
786 schemaName := call.SchemaName
787 if schemaName == "" {
788 schemaName = "response"
789 }
790
791 fantasyCall := fantasy.Call{
792 Prompt: call.Prompt,
793 MaxOutputTokens: call.MaxOutputTokens,
794 Temperature: call.Temperature,
795 TopP: call.TopP,
796 PresencePenalty: call.PresencePenalty,
797 FrequencyPenalty: call.FrequencyPenalty,
798 ProviderOptions: call.ProviderOptions,
799 }
800
801 params, warnings, err := o.prepareParams(fantasyCall)
802 if err != nil {
803 return nil, err
804 }
805
806 params.ResponseFormat = openai.ChatCompletionNewParamsResponseFormatUnion{
807 OfJSONSchema: &shared.ResponseFormatJSONSchemaParam{
808 JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
809 Name: schemaName,
810 Description: param.NewOpt(call.SchemaDescription),
811 Schema: jsonSchemaMap,
812 Strict: param.NewOpt(true),
813 },
814 },
815 }
816
817 params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
818 IncludeUsage: openai.Bool(true),
819 }
820
821 stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
822
823 return func(yield func(fantasy.ObjectStreamPart) bool) {
824 if len(warnings) > 0 {
825 if !yield(fantasy.ObjectStreamPart{
826 Type: fantasy.ObjectStreamPartTypeObject,
827 Warnings: warnings,
828 }) {
829 return
830 }
831 }
832
833 var accumulated string
834 var lastParsedObject any
835 var usage fantasy.Usage
836 var finishReason fantasy.FinishReason
837 var providerMetadata fantasy.ProviderMetadata
838 var streamErr error
839
840 for stream.Next() {
841 chunk := stream.Current()
842
843 // Update usage
844 usage, providerMetadata = o.streamUsageFunc(chunk, make(map[string]any), providerMetadata)
845
846 if len(chunk.Choices) == 0 {
847 continue
848 }
849
850 choice := chunk.Choices[0]
851 if choice.FinishReason != "" {
852 finishReason = o.mapFinishReasonFunc(choice.FinishReason)
853 }
854
855 if choice.Delta.Content != "" {
856 accumulated += choice.Delta.Content
857
858 obj, state, parseErr := schema.ParsePartialJSON(accumulated)
859
860 if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
861 if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
862 if !reflect.DeepEqual(obj, lastParsedObject) {
863 if !yield(fantasy.ObjectStreamPart{
864 Type: fantasy.ObjectStreamPartTypeObject,
865 Object: obj,
866 }) {
867 return
868 }
869 lastParsedObject = obj
870 }
871 }
872 }
873
874 if state == schema.ParseStateFailed && call.RepairText != nil {
875 repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
876 if repairErr == nil {
877 obj2, state2, _ := schema.ParsePartialJSON(repairedText)
878 if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
879 schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
880 if !reflect.DeepEqual(obj2, lastParsedObject) {
881 if !yield(fantasy.ObjectStreamPart{
882 Type: fantasy.ObjectStreamPartTypeObject,
883 Object: obj2,
884 }) {
885 return
886 }
887 lastParsedObject = obj2
888 }
889 }
890 }
891 }
892 }
893 }
894
895 err := stream.Err()
896 if err != nil && !errors.Is(err, io.EOF) {
897 streamErr = toProviderErr(err)
898 yield(fantasy.ObjectStreamPart{
899 Type: fantasy.ObjectStreamPartTypeError,
900 Error: streamErr,
901 })
902 return
903 }
904
905 if lastParsedObject != nil {
906 yield(fantasy.ObjectStreamPart{
907 Type: fantasy.ObjectStreamPartTypeFinish,
908 Usage: usage,
909 FinishReason: finishReason,
910 ProviderMetadata: providerMetadata,
911 })
912 } else {
913 yield(fantasy.ObjectStreamPart{
914 Type: fantasy.ObjectStreamPartTypeError,
915 Error: &fantasy.NoObjectGeneratedError{
916 RawText: accumulated,
917 ParseError: fmt.Errorf("no valid object generated in stream"),
918 Usage: usage,
919 FinishReason: finishReason,
920 },
921 })
922 }
923 }, nil
924}
925
// addAdditionalPropertiesFalse recursively adds "additionalProperties": false
// to every object schema in the map, as required by OpenAI's strict mode for
// structured outputs. An explicitly set additionalProperties value is left
// untouched. The map is modified in place.
//
// Fix: the parameter was named schema, shadowing the imported
// charm.land/fantasy/schema package within this function; renamed to
// schemaMap (purely internal — callers are unaffected).
func addAdditionalPropertiesFalse(schemaMap map[string]any) {
	if schemaMap["type"] == "object" {
		// Only fill in the default; respect an explicit setting.
		if _, hasAdditional := schemaMap["additionalProperties"]; !hasAdditional {
			schemaMap["additionalProperties"] = false
		}

		// Recursively process nested property schemas.
		if properties, ok := schemaMap["properties"].(map[string]any); ok {
			for _, propValue := range properties {
				if propSchema, ok := propValue.(map[string]any); ok {
					addAdditionalPropertiesFalse(propSchema)
				}
			}
		}
	}

	// Array schemas nest their element schema under "items".
	if items, ok := schemaMap["items"].(map[string]any); ok {
		addAdditionalPropertiesFalse(items)
	}
}