1package openai
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "reflect"
10 "strings"
11
12 "charm.land/fantasy"
13 "charm.land/fantasy/object"
14 "charm.land/fantasy/schema"
15 xjson "github.com/charmbracelet/x/json"
16 "github.com/google/uuid"
17 "github.com/openai/openai-go/v2"
18 "github.com/openai/openai-go/v2/packages/param"
19 "github.com/openai/openai-go/v2/shared"
20)
21
// languageModel implements fantasy.LanguageModel on top of the OpenAI
// chat-completions API. The *Func fields are hooks that let
// OpenAI-compatible providers customize request preparation, finish-reason
// mapping, usage extraction, and prompt conversion; defaults are installed
// by newLanguageModel.
type languageModel struct {
	provider string // reported by Provider()
	modelID  string // reported by Model()
	client   openai.Client
	// objectMode selects the strategy used by GenerateObject/StreamObject.
	objectMode fantasy.ObjectMode
	// prepareCallFunc mutates params before each request and may add warnings.
	prepareCallFunc LanguageModelPrepareCallFunc
	// mapFinishReasonFunc converts the raw finish reason string to a fantasy one.
	mapFinishReasonFunc LanguageModelMapFinishReasonFunc
	// extraContentFunc may derive additional content parts from a choice.
	extraContentFunc LanguageModelExtraContentFunc
	// usageFunc extracts usage and provider metadata from a full response.
	usageFunc LanguageModelUsageFunc
	// streamUsageFunc updates usage/metadata from each streamed chunk.
	streamUsageFunc LanguageModelStreamUsageFunc
	// streamExtraFunc lets providers emit extra stream parts per chunk.
	streamExtraFunc LanguageModelStreamExtraFunc
	// streamProviderMetadataFunc finalizes metadata from the accumulated choice.
	streamProviderMetadataFunc LanguageModelStreamProviderMetadataFunc
	// toPromptFunc converts a fantasy prompt into OpenAI messages.
	toPromptFunc LanguageModelToPromptFunc
}
36
// LanguageModelOption is a function that configures a languageModel.
// Options are applied in order by newLanguageModel after defaults are set.
type LanguageModelOption = func(*languageModel)
39
40// WithLanguageModelPrepareCallFunc sets the prepare call function for the language model.
41func WithLanguageModelPrepareCallFunc(fn LanguageModelPrepareCallFunc) LanguageModelOption {
42 return func(l *languageModel) {
43 l.prepareCallFunc = fn
44 }
45}
46
47// WithLanguageModelMapFinishReasonFunc sets the map finish reason function for the language model.
48func WithLanguageModelMapFinishReasonFunc(fn LanguageModelMapFinishReasonFunc) LanguageModelOption {
49 return func(l *languageModel) {
50 l.mapFinishReasonFunc = fn
51 }
52}
53
54// WithLanguageModelExtraContentFunc sets the extra content function for the language model.
55func WithLanguageModelExtraContentFunc(fn LanguageModelExtraContentFunc) LanguageModelOption {
56 return func(l *languageModel) {
57 l.extraContentFunc = fn
58 }
59}
60
61// WithLanguageModelStreamExtraFunc sets the stream extra function for the language model.
62func WithLanguageModelStreamExtraFunc(fn LanguageModelStreamExtraFunc) LanguageModelOption {
63 return func(l *languageModel) {
64 l.streamExtraFunc = fn
65 }
66}
67
68// WithLanguageModelUsageFunc sets the usage function for the language model.
69func WithLanguageModelUsageFunc(fn LanguageModelUsageFunc) LanguageModelOption {
70 return func(l *languageModel) {
71 l.usageFunc = fn
72 }
73}
74
75// WithLanguageModelStreamUsageFunc sets the stream usage function for the language model.
76func WithLanguageModelStreamUsageFunc(fn LanguageModelStreamUsageFunc) LanguageModelOption {
77 return func(l *languageModel) {
78 l.streamUsageFunc = fn
79 }
80}
81
82// WithLanguageModelToPromptFunc sets the to prompt function for the language model.
83func WithLanguageModelToPromptFunc(fn LanguageModelToPromptFunc) LanguageModelOption {
84 return func(l *languageModel) {
85 l.toPromptFunc = fn
86 }
87}
88
89// WithLanguageModelObjectMode sets the object generation mode.
90func WithLanguageModelObjectMode(om fantasy.ObjectMode) LanguageModelOption {
91 return func(l *languageModel) {
92 // not supported
93 if om == fantasy.ObjectModeJSON {
94 om = fantasy.ObjectModeAuto
95 }
96 l.objectMode = om
97 }
98}
99
100func newLanguageModel(modelID string, provider string, client openai.Client, opts ...LanguageModelOption) languageModel {
101 model := languageModel{
102 modelID: modelID,
103 provider: provider,
104 client: client,
105 objectMode: fantasy.ObjectModeAuto,
106 prepareCallFunc: DefaultPrepareCallFunc,
107 mapFinishReasonFunc: DefaultMapFinishReasonFunc,
108 usageFunc: DefaultUsageFunc,
109 streamUsageFunc: DefaultStreamUsageFunc,
110 streamProviderMetadataFunc: DefaultStreamProviderMetadataFunc,
111 toPromptFunc: DefaultToPrompt,
112 }
113
114 for _, o := range opts {
115 o(&model)
116 }
117 return model
118}
119
// streamToolCall accumulates a single tool call across streaming deltas.
type streamToolCall struct {
	id        string
	name      string
	arguments string // concatenated JSON argument fragments
	// hasFinished is set once arguments form valid JSON and the final
	// tool-call part has been emitted; later deltas are ignored.
	hasFinished bool
}
126
// Model implements fantasy.LanguageModel. It returns the model identifier.
func (o languageModel) Model() string {
	return o.modelID
}
131
// Provider implements fantasy.LanguageModel. It returns the provider name.
func (o languageModel) Provider() string {
	return o.provider
}
136
// prepareParams converts a fantasy.Call into OpenAI chat-completion params,
// collecting warnings for settings the target model does not support. It
// applies reasoning-model and search-preview-model restrictions, runs the
// provider's prepareCallFunc hook, and attaches messages and tools.
func (o languageModel) prepareParams(call fantasy.Call) (*openai.ChatCompletionNewParams, []fantasy.CallWarning, error) {
	params := &openai.ChatCompletionNewParams{}
	messages, warnings := o.toPromptFunc(call.Prompt, o.provider, o.modelID)
	// top_k has no equivalent in the chat completions API.
	if call.TopK != nil {
		warnings = append(warnings, fantasy.CallWarning{
			Type:    fantasy.CallWarningTypeUnsupportedSetting,
			Setting: "top_k",
		})
	}

	if call.MaxOutputTokens != nil {
		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
	}
	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}
	if call.FrequencyPenalty != nil {
		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
	}
	if call.PresencePenalty != nil {
		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
	}

	if isReasoningModel(o.modelID) {
		// remove unsupported settings for reasoning models
		// see https://platform.openai.com/docs/guides/reasoning#limitations
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for reasoning models",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "TopP",
				Details: "TopP is not supported for reasoning models",
			})
		}
		if call.FrequencyPenalty != nil {
			params.FrequencyPenalty = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "FrequencyPenalty",
				Details: "FrequencyPenalty is not supported for reasoning models",
			})
		}
		if call.PresencePenalty != nil {
			params.PresencePenalty = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "PresencePenalty",
				Details: "PresencePenalty is not supported for reasoning models",
			})
		}

		// reasoning models use max_completion_tokens instead of max_tokens.
		// Only set it if a hook has not already done so.
		if call.MaxOutputTokens != nil {
			if !params.MaxCompletionTokens.Valid() {
				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
			}
			params.MaxTokens = param.Opt[int64]{}
		}
	}

	// Handle search preview models
	if isSearchPreviewModel(o.modelID) {
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for the search preview models and has been removed.",
			})
		}
	}

	// Provider-specific hook: may further adjust params and add warnings.
	optionsWarnings, err := o.prepareCallFunc(o, params, call)
	if err != nil {
		return nil, nil, err
	}

	if len(optionsWarnings) > 0 {
		warnings = append(warnings, optionsWarnings...)
	}

	params.Messages = messages
	params.Model = o.modelID

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}
	return params, warnings, nil
}
242
// Generate implements fantasy.LanguageModel. It performs a single
// (non-streaming) chat completion and converts the first choice into
// fantasy content parts, usage, and a finish reason.
func (o languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, toProviderErr(err)
	}

	if len(response.Choices) == 0 {
		return nil, &fantasy.Error{Title: "no response", Message: "no response generated"}
	}
	choice := response.Choices[0]
	// Capacity: one possible text part plus tool calls and annotations.
	content := make([]fantasy.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
	text := choice.Message.Content
	if text != "" {
		content = append(content, fantasy.TextContent{
			Text: text,
		})
	}
	// Provider-specific hook may contribute extra content parts.
	if o.extraContentFunc != nil {
		extraContent := o.extraContentFunc(choice)
		content = append(content, extraContent...)
	}
	for _, tc := range choice.Message.ToolCalls {
		toolCallID := tc.ID
		content = append(content, fantasy.ToolCallContent{
			ProviderExecuted: false,
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Surface URL citations as source content parts.
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, fantasy.SourceContent{
				SourceType: fantasy.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	usage, providerMetadata := o.usageFunc(*response)

	// Prefer the tool-calls finish reason whenever tool calls are present,
	// regardless of the reason the API reported.
	mappedFinishReason := o.mapFinishReasonFunc(choice.FinishReason)
	if len(choice.Message.ToolCalls) > 0 {
		mappedFinishReason = fantasy.FinishReasonToolCalls
	}
	return &fantasy.Response{
		Content:      content,
		Usage:        usage,
		FinishReason: mappedFinishReason,
		ProviderMetadata: fantasy.ProviderMetadata{
			Name: providerMetadata,
		},
		Warnings: warnings,
	}, nil
}
305
306// Stream implements fantasy.LanguageModel.
307func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
308 params, warnings, err := o.prepareParams(call)
309 if err != nil {
310 return nil, err
311 }
312
313 params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
314 IncludeUsage: openai.Bool(true),
315 }
316
317 stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
318 isActiveText := false
319 toolCalls := make(map[int64]streamToolCall)
320
321 providerMetadata := fantasy.ProviderMetadata{
322 Name: &ProviderMetadata{},
323 }
324 acc := openai.ChatCompletionAccumulator{}
325 extraContext := make(map[string]any)
326 var usage fantasy.Usage
327 var finishReason string
328 return func(yield func(fantasy.StreamPart) bool) {
329 if len(warnings) > 0 {
330 if !yield(fantasy.StreamPart{
331 Type: fantasy.StreamPartTypeWarnings,
332 Warnings: warnings,
333 }) {
334 return
335 }
336 }
337 for stream.Next() {
338 chunk := stream.Current()
339 acc.AddChunk(chunk)
340 usage, providerMetadata = o.streamUsageFunc(chunk, extraContext, providerMetadata)
341 if len(chunk.Choices) == 0 {
342 continue
343 }
344 for _, choice := range chunk.Choices {
345 if choice.FinishReason != "" {
346 finishReason = choice.FinishReason
347 }
348 switch {
349 case choice.Delta.Content != "":
350 if !isActiveText {
351 isActiveText = true
352 if !yield(fantasy.StreamPart{
353 Type: fantasy.StreamPartTypeTextStart,
354 ID: "0",
355 }) {
356 return
357 }
358 }
359 if !yield(fantasy.StreamPart{
360 Type: fantasy.StreamPartTypeTextDelta,
361 ID: "0",
362 Delta: choice.Delta.Content,
363 }) {
364 return
365 }
366 case len(choice.Delta.ToolCalls) > 0:
367 if isActiveText {
368 isActiveText = false
369 if !yield(fantasy.StreamPart{
370 Type: fantasy.StreamPartTypeTextEnd,
371 ID: "0",
372 }) {
373 return
374 }
375 }
376
377 for _, toolCallDelta := range choice.Delta.ToolCalls {
378 if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
379 if existingToolCall.hasFinished {
380 continue
381 }
382 if toolCallDelta.Function.Arguments != "" {
383 existingToolCall.arguments += toolCallDelta.Function.Arguments
384 }
385 if !yield(fantasy.StreamPart{
386 Type: fantasy.StreamPartTypeToolInputDelta,
387 ID: existingToolCall.id,
388 Delta: toolCallDelta.Function.Arguments,
389 }) {
390 return
391 }
392 toolCalls[toolCallDelta.Index] = existingToolCall
393 if xjson.IsValid(existingToolCall.arguments) {
394 if !yield(fantasy.StreamPart{
395 Type: fantasy.StreamPartTypeToolInputEnd,
396 ID: existingToolCall.id,
397 }) {
398 return
399 }
400
401 if !yield(fantasy.StreamPart{
402 Type: fantasy.StreamPartTypeToolCall,
403 ID: existingToolCall.id,
404 ToolCallName: existingToolCall.name,
405 ToolCallInput: existingToolCall.arguments,
406 }) {
407 return
408 }
409 existingToolCall.hasFinished = true
410 toolCalls[toolCallDelta.Index] = existingToolCall
411 }
412 } else {
413 var err error
414 if toolCallDelta.Type != "function" {
415 err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function' type."}
416 }
417 if toolCallDelta.ID == "" {
418 err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'id' to be a string."}
419 }
420 if toolCallDelta.Function.Name == "" {
421 err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function.name' to be a string."}
422 }
423 if err != nil {
424 yield(fantasy.StreamPart{
425 Type: fantasy.StreamPartTypeError,
426 Error: toProviderErr(stream.Err()),
427 })
428 return
429 }
430
431 if !yield(fantasy.StreamPart{
432 Type: fantasy.StreamPartTypeToolInputStart,
433 ID: toolCallDelta.ID,
434 ToolCallName: toolCallDelta.Function.Name,
435 }) {
436 return
437 }
438 toolCalls[toolCallDelta.Index] = streamToolCall{
439 id: toolCallDelta.ID,
440 name: toolCallDelta.Function.Name,
441 arguments: toolCallDelta.Function.Arguments,
442 }
443
444 exTc := toolCalls[toolCallDelta.Index]
445 if exTc.arguments != "" {
446 if !yield(fantasy.StreamPart{
447 Type: fantasy.StreamPartTypeToolInputDelta,
448 ID: exTc.id,
449 Delta: exTc.arguments,
450 }) {
451 return
452 }
453 if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
454 if !yield(fantasy.StreamPart{
455 Type: fantasy.StreamPartTypeToolInputEnd,
456 ID: toolCallDelta.ID,
457 }) {
458 return
459 }
460
461 if !yield(fantasy.StreamPart{
462 Type: fantasy.StreamPartTypeToolCall,
463 ID: exTc.id,
464 ToolCallName: exTc.name,
465 ToolCallInput: exTc.arguments,
466 }) {
467 return
468 }
469 exTc.hasFinished = true
470 toolCalls[toolCallDelta.Index] = exTc
471 }
472 }
473 continue
474 }
475 }
476 }
477
478 if o.streamExtraFunc != nil {
479 updatedContext, shouldContinue := o.streamExtraFunc(chunk, yield, extraContext)
480 if !shouldContinue {
481 return
482 }
483 extraContext = updatedContext
484 }
485 }
486
487 for _, choice := range chunk.Choices {
488 if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
489 for _, annotation := range annotations {
490 if annotation.Type == "url_citation" {
491 if !yield(fantasy.StreamPart{
492 Type: fantasy.StreamPartTypeSource,
493 ID: uuid.NewString(),
494 SourceType: fantasy.SourceTypeURL,
495 URL: annotation.URLCitation.URL,
496 Title: annotation.URLCitation.Title,
497 }) {
498 return
499 }
500 }
501 }
502 }
503 }
504 }
505 err := stream.Err()
506 if err == nil || errors.Is(err, io.EOF) {
507 if isActiveText {
508 isActiveText = false
509 if !yield(fantasy.StreamPart{
510 Type: fantasy.StreamPartTypeTextEnd,
511 ID: "0",
512 }) {
513 return
514 }
515 }
516
517 if len(acc.Choices) > 0 {
518 choice := acc.Choices[0]
519 providerMetadata = o.streamProviderMetadataFunc(choice, providerMetadata)
520
521 for _, annotation := range choice.Message.Annotations {
522 if annotation.Type == "url_citation" {
523 if !yield(fantasy.StreamPart{
524 Type: fantasy.StreamPartTypeSource,
525 ID: acc.ID,
526 SourceType: fantasy.SourceTypeURL,
527 URL: annotation.URLCitation.URL,
528 Title: annotation.URLCitation.Title,
529 }) {
530 return
531 }
532 }
533 }
534 }
535 mappedFinishReason := o.mapFinishReasonFunc(finishReason)
536 if len(acc.Choices) > 0 {
537 choice := acc.Choices[0]
538 if len(choice.Message.ToolCalls) > 0 {
539 mappedFinishReason = fantasy.FinishReasonToolCalls
540 }
541 }
542 yield(fantasy.StreamPart{
543 Type: fantasy.StreamPartTypeFinish,
544 Usage: usage,
545 FinishReason: mappedFinishReason,
546 ProviderMetadata: providerMetadata,
547 })
548 return
549 } else { //nolint: revive
550 yield(fantasy.StreamPart{
551 Type: fantasy.StreamPartTypeError,
552 Error: toProviderErr(err),
553 })
554 return
555 }
556 }, nil
557}
558
// isReasoningModel reports whether the model is an OpenAI reasoning model
// (o-series or the gpt-5 family). The gpt-5-chat models are standard chat
// models and are excluded even though they share the "gpt-5" prefix; the
// previous "|| HasPrefix gpt-5-chat" clause was dead code (already implied
// by the gpt-5 prefix) and inverted the intended exclusion.
func isReasoningModel(modelID string) bool {
	if strings.HasPrefix(modelID, "gpt-5-chat") {
		return false
	}
	return strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5")
}
562
// isSearchPreviewModel reports whether the model is one of the
// *-search-preview chat models, which reject the temperature setting.
func isSearchPreviewModel(modelID string) bool {
	return strings.Contains(modelID, "search-preview")
}
566
// supportsFlexProcessing reports whether the model accepts the "flex"
// service tier (o3, o4-mini, and gpt-5 families).
func supportsFlexProcessing(modelID string) bool {
	for _, prefix := range []string{"o3", "o4-mini", "gpt-5"} {
		if strings.HasPrefix(modelID, prefix) {
			return true
		}
	}
	return false
}
570
// supportsPriorityProcessing reports whether the model accepts the
// "priority" service tier. The former gpt-5-mini check was dead code —
// it is already covered by the gpt-5 prefix — so it has been removed;
// behavior is unchanged.
func supportsPriorityProcessing(modelID string) bool {
	return strings.HasPrefix(modelID, "gpt-4") || strings.HasPrefix(modelID, "gpt-5") ||
		strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini")
}
576
// toOpenAiTools converts fantasy tools and an optional tool choice into the
// OpenAI chat-completions representations. Non-function tools are skipped
// with an unsupported-tool warning.
func toOpenAiTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []fantasy.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == fantasy.ToolTypeFunction {
			ft, ok := tool.(fantasy.FunctionTool)
			if !ok {
				// Type claims function but concrete type differs; skip silently.
				continue
			}
			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
				OfFunction: &openai.ChatCompletionFunctionToolParam{
					Function: shared.FunctionDefinitionParam{
						Name:        ft.Name,
						Description: param.NewOpt(ft.Description),
						Parameters:  openai.FunctionParameters(ft.InputSchema),
						// Strict mode disabled: schemas are passed through as-is.
						Strict: param.NewOpt(false),
					},
					Type: "function",
				},
			})
			continue
		}

		warnings = append(warnings, fantasy.CallWarning{
			Type:    fantasy.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return openAiTools, openAiToolChoice, warnings
	}

	switch *toolChoice {
	case fantasy.ToolChoiceAuto:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("auto"),
		}
	case fantasy.ToolChoiceNone:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("none"),
		}
	default:
		// Any other value is treated as the name of a specific tool to force.
		// NOTE(review): if the fantasy package defines a "required" tool
		// choice constant, it would fall through to here and be sent as a
		// function name — verify against fantasy.ToolChoice's values.
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
				Type: "function",
				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
					Name: string(*toolChoice),
				},
			},
		}
	}
	return openAiTools, openAiToolChoice, warnings
}
629
630// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
631func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
632 var annotations []openai.ChatCompletionMessageAnnotation
633
634 // Parse the raw JSON to extract annotations
635 var deltaData map[string]any
636 if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
637 return annotations
638 }
639
640 // Check if annotations exist in the delta
641 if annotationsData, ok := deltaData["annotations"].([]any); ok {
642 for _, annotationData := range annotationsData {
643 if annotationMap, ok := annotationData.(map[string]any); ok {
644 if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
645 if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
646 annotation := openai.ChatCompletionMessageAnnotation{
647 Type: "url_citation",
648 URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
649 URL: urlCitationData["url"].(string),
650 Title: urlCitationData["title"].(string),
651 },
652 }
653 annotations = append(annotations, annotation)
654 }
655 }
656 }
657 }
658 }
659
660 return annotations
661}
662
663// GenerateObject implements fantasy.LanguageModel.
664func (o languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
665 switch o.objectMode {
666 case fantasy.ObjectModeText:
667 return object.GenerateWithText(ctx, o, call)
668 case fantasy.ObjectModeTool:
669 return object.GenerateWithTool(ctx, o, call)
670 default:
671 return o.generateObjectWithJSONMode(ctx, call)
672 }
673}
674
675// StreamObject implements fantasy.LanguageModel.
676func (o languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
677 switch o.objectMode {
678 case fantasy.ObjectModeTool:
679 return object.StreamWithTool(ctx, o, call)
680 case fantasy.ObjectModeText:
681 return object.StreamWithText(ctx, o, call)
682 default:
683 return o.streamObjectWithJSONMode(ctx, call)
684 }
685}
686
687func (o languageModel) generateObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
688 jsonSchemaMap := schema.ToMap(call.Schema)
689
690 addAdditionalPropertiesFalse(jsonSchemaMap)
691
692 schemaName := call.SchemaName
693 if schemaName == "" {
694 schemaName = "response"
695 }
696
697 fantasyCall := fantasy.Call{
698 Prompt: call.Prompt,
699 MaxOutputTokens: call.MaxOutputTokens,
700 Temperature: call.Temperature,
701 TopP: call.TopP,
702 PresencePenalty: call.PresencePenalty,
703 FrequencyPenalty: call.FrequencyPenalty,
704 ProviderOptions: call.ProviderOptions,
705 }
706
707 params, warnings, err := o.prepareParams(fantasyCall)
708 if err != nil {
709 return nil, err
710 }
711
712 params.ResponseFormat = openai.ChatCompletionNewParamsResponseFormatUnion{
713 OfJSONSchema: &shared.ResponseFormatJSONSchemaParam{
714 JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
715 Name: schemaName,
716 Description: param.NewOpt(call.SchemaDescription),
717 Schema: jsonSchemaMap,
718 Strict: param.NewOpt(true),
719 },
720 },
721 }
722
723 response, err := o.client.Chat.Completions.New(ctx, *params)
724 if err != nil {
725 return nil, toProviderErr(err)
726 }
727
728 if len(response.Choices) == 0 {
729 usage, _ := o.usageFunc(*response)
730 return nil, &fantasy.NoObjectGeneratedError{
731 RawText: "",
732 ParseError: fmt.Errorf("no choices in response"),
733 Usage: usage,
734 FinishReason: fantasy.FinishReasonUnknown,
735 }
736 }
737
738 choice := response.Choices[0]
739 jsonText := choice.Message.Content
740
741 var obj any
742 if call.RepairText != nil {
743 obj, err = schema.ParseAndValidateWithRepair(ctx, jsonText, call.Schema, call.RepairText)
744 } else {
745 obj, err = schema.ParseAndValidate(jsonText, call.Schema)
746 }
747
748 usage, _ := o.usageFunc(*response)
749 finishReason := o.mapFinishReasonFunc(choice.FinishReason)
750
751 if err != nil {
752 if nogErr, ok := err.(*fantasy.NoObjectGeneratedError); ok {
753 nogErr.Usage = usage
754 nogErr.FinishReason = finishReason
755 }
756 return nil, err
757 }
758
759 return &fantasy.ObjectResponse{
760 Object: obj,
761 RawText: jsonText,
762 Usage: usage,
763 FinishReason: finishReason,
764 Warnings: warnings,
765 }, nil
766}
767
// streamObjectWithJSONMode streams a structured object using OpenAI's
// JSON-schema response format. As content deltas arrive it repeatedly
// attempts a partial-JSON parse of the accumulated text and yields each new
// distinct object that validates against the schema; the final part carries
// usage and finish reason, or an error if no valid object was produced.
func (o languageModel) streamObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
	jsonSchemaMap := schema.ToMap(call.Schema)

	// OpenAI strict mode requires additionalProperties:false on object schemas.
	addAdditionalPropertiesFalse(jsonSchemaMap)

	schemaName := call.SchemaName
	if schemaName == "" {
		schemaName = "response"
	}

	fantasyCall := fantasy.Call{
		Prompt:           call.Prompt,
		MaxOutputTokens:  call.MaxOutputTokens,
		Temperature:      call.Temperature,
		TopP:             call.TopP,
		PresencePenalty:  call.PresencePenalty,
		FrequencyPenalty: call.FrequencyPenalty,
		ProviderOptions:  call.ProviderOptions,
	}

	params, warnings, err := o.prepareParams(fantasyCall)
	if err != nil {
		return nil, err
	}

	params.ResponseFormat = openai.ChatCompletionNewParamsResponseFormatUnion{
		OfJSONSchema: &shared.ResponseFormatJSONSchemaParam{
			JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
				Name:        schemaName,
				Description: param.NewOpt(call.SchemaDescription),
				Schema:      jsonSchemaMap,
				Strict:      param.NewOpt(true),
			},
		},
	}

	// Request token usage in the final chunk.
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)

	return func(yield func(fantasy.ObjectStreamPart) bool) {
		// NOTE(review): warnings are emitted with the Object part type;
		// confirm there is no dedicated warnings part type for object streams.
		if len(warnings) > 0 {
			if !yield(fantasy.ObjectStreamPart{
				Type:     fantasy.ObjectStreamPartTypeObject,
				Warnings: warnings,
			}) {
				return
			}
		}

		var accumulated string
		var lastParsedObject any
		var usage fantasy.Usage
		var finishReason fantasy.FinishReason
		var providerMetadata fantasy.ProviderMetadata
		var streamErr error

		for stream.Next() {
			chunk := stream.Current()

			// Update usage
			usage, providerMetadata = o.streamUsageFunc(chunk, make(map[string]any), providerMetadata)

			if len(chunk.Choices) == 0 {
				continue
			}

			choice := chunk.Choices[0]
			if choice.FinishReason != "" {
				finishReason = o.mapFinishReasonFunc(choice.FinishReason)
			}

			if choice.Delta.Content != "" {
				accumulated += choice.Delta.Content

				obj, state, parseErr := schema.ParsePartialJSON(accumulated)

				// Yield only when the parse is usable, the object validates,
				// and it differs from the last object already yielded.
				if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
					if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
						if !reflect.DeepEqual(obj, lastParsedObject) {
							if !yield(fantasy.ObjectStreamPart{
								Type:   fantasy.ObjectStreamPartTypeObject,
								Object: obj,
							}) {
								return
							}
							lastParsedObject = obj
						}
					}
				}

				// Optional caller-supplied repair pass when parsing failed.
				if state == schema.ParseStateFailed && call.RepairText != nil {
					repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
					if repairErr == nil {
						obj2, state2, _ := schema.ParsePartialJSON(repairedText)
						if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
							schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
							if !reflect.DeepEqual(obj2, lastParsedObject) {
								if !yield(fantasy.ObjectStreamPart{
									Type:   fantasy.ObjectStreamPartTypeObject,
									Object: obj2,
								}) {
									return
								}
								lastParsedObject = obj2
							}
						}
					}
				}
			}
		}

		err := stream.Err()
		if err != nil && !errors.Is(err, io.EOF) {
			streamErr = toProviderErr(err)
			yield(fantasy.ObjectStreamPart{
				Type:  fantasy.ObjectStreamPartTypeError,
				Error: streamErr,
			})
			return
		}

		if lastParsedObject != nil {
			yield(fantasy.ObjectStreamPart{
				Type:             fantasy.ObjectStreamPartTypeFinish,
				Usage:            usage,
				FinishReason:     finishReason,
				ProviderMetadata: providerMetadata,
			})
		} else {
			// The stream ended without ever producing a valid object.
			yield(fantasy.ObjectStreamPart{
				Type: fantasy.ObjectStreamPartTypeError,
				Error: &fantasy.NoObjectGeneratedError{
					RawText:      accumulated,
					ParseError:   fmt.Errorf("no valid object generated in stream"),
					Usage:        usage,
					FinishReason: finishReason,
				},
			})
		}
	}, nil
}
912
// addAdditionalPropertiesFalse recursively adds "additionalProperties": false
// to all object schemas, as required by OpenAI's strict mode for structured
// outputs. Compared to the previous version it also descends into
// $defs/definitions and anyOf/oneOf/allOf branches (strict mode requires the
// flag on every object in the schema, not just those reachable through
// properties/items), and the parameter no longer shadows the imported
// schema package.
func addAdditionalPropertiesFalse(schemaMap map[string]any) {
	if schemaMap["type"] == "object" {
		if _, hasAdditional := schemaMap["additionalProperties"]; !hasAdditional {
			schemaMap["additionalProperties"] = false
		}
	}

	// Recursively process nested property schemas.
	if properties, ok := schemaMap["properties"].(map[string]any); ok {
		for _, propValue := range properties {
			if propSchema, ok := propValue.(map[string]any); ok {
				addAdditionalPropertiesFalse(propSchema)
			}
		}
	}

	// Handle array items.
	if items, ok := schemaMap["items"].(map[string]any); ok {
		addAdditionalPropertiesFalse(items)
	}

	// Handle reusable definitions.
	for _, key := range []string{"$defs", "definitions"} {
		if defs, ok := schemaMap[key].(map[string]any); ok {
			for _, defValue := range defs {
				if defSchema, ok := defValue.(map[string]any); ok {
					addAdditionalPropertiesFalse(defSchema)
				}
			}
		}
	}

	// Handle union branches.
	for _, key := range []string{"anyOf", "oneOf", "allOf"} {
		if variants, ok := schemaMap[key].([]any); ok {
			for _, variant := range variants {
				if variantSchema, ok := variant.(map[string]any); ok {
					addAdditionalPropertiesFalse(variantSchema)
				}
			}
		}
	}
}