language_model.go

  1package openai
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"reflect"
 10	"strings"
 11
 12	"charm.land/fantasy"
 13	"charm.land/fantasy/object"
 14	"charm.land/fantasy/schema"
 15	xjson "github.com/charmbracelet/x/json"
 16	"github.com/google/uuid"
 17	"github.com/openai/openai-go/v2"
 18	"github.com/openai/openai-go/v2/packages/param"
 19	"github.com/openai/openai-go/v2/shared"
 20)
 21
// languageModel is an OpenAI-backed implementation of fantasy.LanguageModel.
// The hook functions let OpenAI-compatible providers customize request
// preparation and response interpretation without reimplementing the model.
type languageModel struct {
	provider                   string // provider name returned by Provider()
	modelID                    string // model identifier returned by Model()
	client                     openai.Client
	objectMode                 fantasy.ObjectMode // strategy used by GenerateObject/StreamObject
	prepareCallFunc            LanguageModelPrepareCallFunc
	mapFinishReasonFunc        LanguageModelMapFinishReasonFunc
	extraContentFunc           LanguageModelExtraContentFunc
	usageFunc                  LanguageModelUsageFunc
	streamUsageFunc            LanguageModelStreamUsageFunc
	streamExtraFunc            LanguageModelStreamExtraFunc
	streamProviderMetadataFunc LanguageModelStreamProviderMetadataFunc
	toPromptFunc               LanguageModelToPromptFunc
}
 36
// LanguageModelOption is a function that configures a languageModel.
// Options are applied in order by newLanguageModel.
type LanguageModelOption = func(*languageModel)
 39
 40// WithLanguageModelPrepareCallFunc sets the prepare call function for the language model.
 41func WithLanguageModelPrepareCallFunc(fn LanguageModelPrepareCallFunc) LanguageModelOption {
 42	return func(l *languageModel) {
 43		l.prepareCallFunc = fn
 44	}
 45}
 46
 47// WithLanguageModelMapFinishReasonFunc sets the map finish reason function for the language model.
 48func WithLanguageModelMapFinishReasonFunc(fn LanguageModelMapFinishReasonFunc) LanguageModelOption {
 49	return func(l *languageModel) {
 50		l.mapFinishReasonFunc = fn
 51	}
 52}
 53
 54// WithLanguageModelExtraContentFunc sets the extra content function for the language model.
 55func WithLanguageModelExtraContentFunc(fn LanguageModelExtraContentFunc) LanguageModelOption {
 56	return func(l *languageModel) {
 57		l.extraContentFunc = fn
 58	}
 59}
 60
 61// WithLanguageModelStreamExtraFunc sets the stream extra function for the language model.
 62func WithLanguageModelStreamExtraFunc(fn LanguageModelStreamExtraFunc) LanguageModelOption {
 63	return func(l *languageModel) {
 64		l.streamExtraFunc = fn
 65	}
 66}
 67
 68// WithLanguageModelUsageFunc sets the usage function for the language model.
 69func WithLanguageModelUsageFunc(fn LanguageModelUsageFunc) LanguageModelOption {
 70	return func(l *languageModel) {
 71		l.usageFunc = fn
 72	}
 73}
 74
 75// WithLanguageModelStreamUsageFunc sets the stream usage function for the language model.
 76func WithLanguageModelStreamUsageFunc(fn LanguageModelStreamUsageFunc) LanguageModelOption {
 77	return func(l *languageModel) {
 78		l.streamUsageFunc = fn
 79	}
 80}
 81
 82// WithLanguageModelToPromptFunc sets the to prompt function for the language model.
 83func WithLanguageModelToPromptFunc(fn LanguageModelToPromptFunc) LanguageModelOption {
 84	return func(l *languageModel) {
 85		l.toPromptFunc = fn
 86	}
 87}
 88
 89// WithLanguageModelObjectMode sets the object generation mode.
 90func WithLanguageModelObjectMode(om fantasy.ObjectMode) LanguageModelOption {
 91	return func(l *languageModel) {
 92		// not supported
 93		if om == fantasy.ObjectModeJSON {
 94			om = fantasy.ObjectModeAuto
 95		}
 96		l.objectMode = om
 97	}
 98}
 99
100func newLanguageModel(modelID string, provider string, client openai.Client, opts ...LanguageModelOption) languageModel {
101	model := languageModel{
102		modelID:                    modelID,
103		provider:                   provider,
104		client:                     client,
105		objectMode:                 fantasy.ObjectModeAuto,
106		prepareCallFunc:            DefaultPrepareCallFunc,
107		mapFinishReasonFunc:        DefaultMapFinishReasonFunc,
108		usageFunc:                  DefaultUsageFunc,
109		streamUsageFunc:            DefaultStreamUsageFunc,
110		streamProviderMetadataFunc: DefaultStreamProviderMetadataFunc,
111		toPromptFunc:               DefaultToPrompt,
112	}
113
114	for _, o := range opts {
115		o(&model)
116	}
117	return model
118}
119
// streamToolCall accumulates one tool call as its argument deltas arrive
// across streaming chunks, keyed by the delta's tool-call index.
type streamToolCall struct {
	id          string
	name        string
	arguments   string // concatenated JSON argument fragments
	hasFinished bool   // set once the completed tool call has been yielded
}
126
// Model implements fantasy.LanguageModel. It returns the configured model ID.
func (o languageModel) Model() string {
	return o.modelID
}
131
// Provider implements fantasy.LanguageModel. It returns the provider name.
func (o languageModel) Provider() string {
	return o.provider
}
136
137func (o languageModel) prepareParams(call fantasy.Call) (*openai.ChatCompletionNewParams, []fantasy.CallWarning, error) {
138	params := &openai.ChatCompletionNewParams{}
139	messages, warnings := o.toPromptFunc(call.Prompt, o.provider, o.modelID)
140	if call.TopK != nil {
141		warnings = append(warnings, fantasy.CallWarning{
142			Type:    fantasy.CallWarningTypeUnsupportedSetting,
143			Setting: "top_k",
144		})
145	}
146
147	if call.MaxOutputTokens != nil {
148		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
149	}
150	if call.Temperature != nil {
151		params.Temperature = param.NewOpt(*call.Temperature)
152	}
153	if call.TopP != nil {
154		params.TopP = param.NewOpt(*call.TopP)
155	}
156	if call.FrequencyPenalty != nil {
157		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
158	}
159	if call.PresencePenalty != nil {
160		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
161	}
162
163	if isReasoningModel(o.modelID) {
164		// remove unsupported settings for reasoning models
165		// see https://platform.openai.com/docs/guides/reasoning#limitations
166		if call.Temperature != nil {
167			params.Temperature = param.Opt[float64]{}
168			warnings = append(warnings, fantasy.CallWarning{
169				Type:    fantasy.CallWarningTypeUnsupportedSetting,
170				Setting: "temperature",
171				Details: "temperature is not supported for reasoning models",
172			})
173		}
174		if call.TopP != nil {
175			params.TopP = param.Opt[float64]{}
176			warnings = append(warnings, fantasy.CallWarning{
177				Type:    fantasy.CallWarningTypeUnsupportedSetting,
178				Setting: "TopP",
179				Details: "TopP is not supported for reasoning models",
180			})
181		}
182		if call.FrequencyPenalty != nil {
183			params.FrequencyPenalty = param.Opt[float64]{}
184			warnings = append(warnings, fantasy.CallWarning{
185				Type:    fantasy.CallWarningTypeUnsupportedSetting,
186				Setting: "FrequencyPenalty",
187				Details: "FrequencyPenalty is not supported for reasoning models",
188			})
189		}
190		if call.PresencePenalty != nil {
191			params.PresencePenalty = param.Opt[float64]{}
192			warnings = append(warnings, fantasy.CallWarning{
193				Type:    fantasy.CallWarningTypeUnsupportedSetting,
194				Setting: "PresencePenalty",
195				Details: "PresencePenalty is not supported for reasoning models",
196			})
197		}
198
199		// reasoning models use max_completion_tokens instead of max_tokens
200		if call.MaxOutputTokens != nil {
201			if !params.MaxCompletionTokens.Valid() {
202				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
203			}
204			params.MaxTokens = param.Opt[int64]{}
205		}
206	}
207
208	// Handle search preview models
209	if isSearchPreviewModel(o.modelID) {
210		if call.Temperature != nil {
211			params.Temperature = param.Opt[float64]{}
212			warnings = append(warnings, fantasy.CallWarning{
213				Type:    fantasy.CallWarningTypeUnsupportedSetting,
214				Setting: "temperature",
215				Details: "temperature is not supported for the search preview models and has been removed.",
216			})
217		}
218	}
219
220	optionsWarnings, err := o.prepareCallFunc(o, params, call)
221	if err != nil {
222		return nil, nil, err
223	}
224
225	if len(optionsWarnings) > 0 {
226		warnings = append(warnings, optionsWarnings...)
227	}
228
229	params.Messages = messages
230	params.Model = o.modelID
231
232	if len(call.Tools) > 0 {
233		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
234		params.Tools = tools
235		if toolChoice != nil {
236			params.ToolChoice = *toolChoice
237		}
238		warnings = append(warnings, toolWarnings...)
239	}
240	return params, warnings, nil
241}
242
243// Generate implements fantasy.LanguageModel.
244func (o languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
245	params, warnings, err := o.prepareParams(call)
246	if err != nil {
247		return nil, err
248	}
249	response, err := o.client.Chat.Completions.New(ctx, *params)
250	if err != nil {
251		return nil, toProviderErr(err)
252	}
253
254	if len(response.Choices) == 0 {
255		return nil, &fantasy.Error{Title: "no response", Message: "no response generated"}
256	}
257	choice := response.Choices[0]
258	content := make([]fantasy.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
259	text := choice.Message.Content
260	if text != "" {
261		content = append(content, fantasy.TextContent{
262			Text: text,
263		})
264	}
265	if o.extraContentFunc != nil {
266		extraContent := o.extraContentFunc(choice)
267		content = append(content, extraContent...)
268	}
269	for _, tc := range choice.Message.ToolCalls {
270		toolCallID := tc.ID
271		content = append(content, fantasy.ToolCallContent{
272			ProviderExecuted: false,
273			ToolCallID:       toolCallID,
274			ToolName:         tc.Function.Name,
275			Input:            tc.Function.Arguments,
276		})
277	}
278	for _, annotation := range choice.Message.Annotations {
279		if annotation.Type == "url_citation" {
280			content = append(content, fantasy.SourceContent{
281				SourceType: fantasy.SourceTypeURL,
282				ID:         uuid.NewString(),
283				URL:        annotation.URLCitation.URL,
284				Title:      annotation.URLCitation.Title,
285			})
286		}
287	}
288
289	usage, providerMetadata := o.usageFunc(*response)
290
291	mappedFinishReason := o.mapFinishReasonFunc(choice.FinishReason)
292	if len(choice.Message.ToolCalls) > 0 {
293		mappedFinishReason = fantasy.FinishReasonToolCalls
294	}
295	return &fantasy.Response{
296		Content:      content,
297		Usage:        usage,
298		FinishReason: mappedFinishReason,
299		ProviderMetadata: fantasy.ProviderMetadata{
300			Name: providerMetadata,
301		},
302		Warnings: warnings,
303	}, nil
304}
305
306// Stream implements fantasy.LanguageModel.
307func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
308	params, warnings, err := o.prepareParams(call)
309	if err != nil {
310		return nil, err
311	}
312
313	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
314		IncludeUsage: openai.Bool(true),
315	}
316
317	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
318	isActiveText := false
319	toolCalls := make(map[int64]streamToolCall)
320
321	providerMetadata := fantasy.ProviderMetadata{
322		Name: &ProviderMetadata{},
323	}
324	acc := openai.ChatCompletionAccumulator{}
325	extraContext := make(map[string]any)
326	var usage fantasy.Usage
327	var finishReason string
328	return func(yield func(fantasy.StreamPart) bool) {
329		if len(warnings) > 0 {
330			if !yield(fantasy.StreamPart{
331				Type:     fantasy.StreamPartTypeWarnings,
332				Warnings: warnings,
333			}) {
334				return
335			}
336		}
337		for stream.Next() {
338			chunk := stream.Current()
339			acc.AddChunk(chunk)
340			usage, providerMetadata = o.streamUsageFunc(chunk, extraContext, providerMetadata)
341			if len(chunk.Choices) == 0 {
342				continue
343			}
344			for _, choice := range chunk.Choices {
345				if choice.FinishReason != "" {
346					finishReason = choice.FinishReason
347				}
348				switch {
349				case choice.Delta.Content != "":
350					if !isActiveText {
351						isActiveText = true
352						if !yield(fantasy.StreamPart{
353							Type: fantasy.StreamPartTypeTextStart,
354							ID:   "0",
355						}) {
356							return
357						}
358					}
359					if !yield(fantasy.StreamPart{
360						Type:  fantasy.StreamPartTypeTextDelta,
361						ID:    "0",
362						Delta: choice.Delta.Content,
363					}) {
364						return
365					}
366				case len(choice.Delta.ToolCalls) > 0:
367					if isActiveText {
368						isActiveText = false
369						if !yield(fantasy.StreamPart{
370							Type: fantasy.StreamPartTypeTextEnd,
371							ID:   "0",
372						}) {
373							return
374						}
375					}
376
377					for _, toolCallDelta := range choice.Delta.ToolCalls {
378						if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
379							if existingToolCall.hasFinished {
380								continue
381							}
382							if toolCallDelta.Function.Arguments != "" {
383								existingToolCall.arguments += toolCallDelta.Function.Arguments
384							}
385							if !yield(fantasy.StreamPart{
386								Type:  fantasy.StreamPartTypeToolInputDelta,
387								ID:    existingToolCall.id,
388								Delta: toolCallDelta.Function.Arguments,
389							}) {
390								return
391							}
392							toolCalls[toolCallDelta.Index] = existingToolCall
393							if xjson.IsValid(existingToolCall.arguments) {
394								if !yield(fantasy.StreamPart{
395									Type: fantasy.StreamPartTypeToolInputEnd,
396									ID:   existingToolCall.id,
397								}) {
398									return
399								}
400
401								if !yield(fantasy.StreamPart{
402									Type:          fantasy.StreamPartTypeToolCall,
403									ID:            existingToolCall.id,
404									ToolCallName:  existingToolCall.name,
405									ToolCallInput: existingToolCall.arguments,
406								}) {
407									return
408								}
409								existingToolCall.hasFinished = true
410								toolCalls[toolCallDelta.Index] = existingToolCall
411							}
412						} else {
413							var err error
414							if toolCallDelta.Type != "function" {
415								err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function' type."}
416							}
417							if toolCallDelta.ID == "" {
418								err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'id' to be a string."}
419							}
420							if toolCallDelta.Function.Name == "" {
421								err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function.name' to be a string."}
422							}
423							if err != nil {
424								yield(fantasy.StreamPart{
425									Type:  fantasy.StreamPartTypeError,
426									Error: toProviderErr(stream.Err()),
427								})
428								return
429							}
430
431							if !yield(fantasy.StreamPart{
432								Type:         fantasy.StreamPartTypeToolInputStart,
433								ID:           toolCallDelta.ID,
434								ToolCallName: toolCallDelta.Function.Name,
435							}) {
436								return
437							}
438							toolCalls[toolCallDelta.Index] = streamToolCall{
439								id:        toolCallDelta.ID,
440								name:      toolCallDelta.Function.Name,
441								arguments: toolCallDelta.Function.Arguments,
442							}
443
444							exTc := toolCalls[toolCallDelta.Index]
445							if exTc.arguments != "" {
446								if !yield(fantasy.StreamPart{
447									Type:  fantasy.StreamPartTypeToolInputDelta,
448									ID:    exTc.id,
449									Delta: exTc.arguments,
450								}) {
451									return
452								}
453								if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
454									if !yield(fantasy.StreamPart{
455										Type: fantasy.StreamPartTypeToolInputEnd,
456										ID:   toolCallDelta.ID,
457									}) {
458										return
459									}
460
461									if !yield(fantasy.StreamPart{
462										Type:          fantasy.StreamPartTypeToolCall,
463										ID:            exTc.id,
464										ToolCallName:  exTc.name,
465										ToolCallInput: exTc.arguments,
466									}) {
467										return
468									}
469									exTc.hasFinished = true
470									toolCalls[toolCallDelta.Index] = exTc
471								}
472							}
473							continue
474						}
475					}
476				}
477
478				if o.streamExtraFunc != nil {
479					updatedContext, shouldContinue := o.streamExtraFunc(chunk, yield, extraContext)
480					if !shouldContinue {
481						return
482					}
483					extraContext = updatedContext
484				}
485			}
486
487			for _, choice := range chunk.Choices {
488				if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
489					for _, annotation := range annotations {
490						if annotation.Type == "url_citation" {
491							if !yield(fantasy.StreamPart{
492								Type:       fantasy.StreamPartTypeSource,
493								ID:         uuid.NewString(),
494								SourceType: fantasy.SourceTypeURL,
495								URL:        annotation.URLCitation.URL,
496								Title:      annotation.URLCitation.Title,
497							}) {
498								return
499							}
500						}
501					}
502				}
503			}
504		}
505		err := stream.Err()
506		if err == nil || errors.Is(err, io.EOF) {
507			if isActiveText {
508				isActiveText = false
509				if !yield(fantasy.StreamPart{
510					Type: fantasy.StreamPartTypeTextEnd,
511					ID:   "0",
512				}) {
513					return
514				}
515			}
516
517			if len(acc.Choices) > 0 {
518				choice := acc.Choices[0]
519				providerMetadata = o.streamProviderMetadataFunc(choice, providerMetadata)
520
521				for _, annotation := range choice.Message.Annotations {
522					if annotation.Type == "url_citation" {
523						if !yield(fantasy.StreamPart{
524							Type:       fantasy.StreamPartTypeSource,
525							ID:         acc.ID,
526							SourceType: fantasy.SourceTypeURL,
527							URL:        annotation.URLCitation.URL,
528							Title:      annotation.URLCitation.Title,
529						}) {
530							return
531						}
532					}
533				}
534			}
535			mappedFinishReason := o.mapFinishReasonFunc(finishReason)
536			if len(acc.Choices) > 0 {
537				choice := acc.Choices[0]
538				if len(choice.Message.ToolCalls) > 0 {
539					mappedFinishReason = fantasy.FinishReasonToolCalls
540				}
541			}
542			yield(fantasy.StreamPart{
543				Type:             fantasy.StreamPartTypeFinish,
544				Usage:            usage,
545				FinishReason:     mappedFinishReason,
546				ProviderMetadata: providerMetadata,
547			})
548			return
549		} else { //nolint: revive
550			yield(fantasy.StreamPart{
551				Type:  fantasy.StreamPartTypeError,
552				Error: toProviderErr(err),
553			})
554			return
555		}
556	}, nil
557}
558
// isReasoningModel reports whether the model ID refers to a reasoning model
// (o1/o3/o4/oss families or the gpt-5 family). Reasoning models reject
// sampling parameters and use max_completion_tokens (see prepareParams).
//
// NOTE(review): the original also checked Contains(modelID, "gpt-5-chat"),
// which was dead code — it is subsumed by Contains(modelID, "gpt-5"). If the
// intent was to EXCLUDE gpt-5-chat from reasoning treatment, this needs a
// negated condition instead; confirm against the provider's model docs.
func isReasoningModel(modelID string) bool {
	return strings.HasPrefix(modelID, "o1") || strings.Contains(modelID, "-o1") ||
		strings.HasPrefix(modelID, "o3") || strings.Contains(modelID, "-o3") ||
		strings.HasPrefix(modelID, "o4") || strings.Contains(modelID, "-o4") ||
		strings.HasPrefix(modelID, "oss") || strings.Contains(modelID, "-oss") ||
		strings.Contains(modelID, "gpt-5")
}
566
// isSearchPreviewModel reports whether the model ID belongs to the
// search-preview family, which rejects the temperature setting.
func isSearchPreviewModel(modelID string) bool {
	const marker = "search-preview"
	return strings.Contains(modelID, marker)
}
570
// supportsFlexProcessing reports whether the model supports the "flex"
// service tier (o3 family, o4-mini, and gpt-5 family).
func supportsFlexProcessing(modelID string) bool {
	switch {
	case strings.HasPrefix(modelID, "o3"), strings.Contains(modelID, "-o3"):
		return true
	case strings.Contains(modelID, "o4-mini"), strings.Contains(modelID, "gpt-5"):
		return true
	default:
		return false
	}
}
575
// supportsPriorityProcessing reports whether the model supports the
// "priority" service tier (gpt-4/gpt-5 families, o3 family, and o4-mini).
//
// The original also checked Contains(modelID, "gpt-5-mini"), which was dead
// code — it is subsumed by Contains(modelID, "gpt-5"); behavior is unchanged.
func supportsPriorityProcessing(modelID string) bool {
	return strings.Contains(modelID, "gpt-4") || strings.Contains(modelID, "gpt-5") ||
		strings.HasPrefix(modelID, "o3") || strings.Contains(modelID, "-o3") ||
		strings.Contains(modelID, "o4-mini")
}
581
582func toOpenAiTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []fantasy.CallWarning) {
583	for _, tool := range tools {
584		if tool.GetType() == fantasy.ToolTypeFunction {
585			ft, ok := tool.(fantasy.FunctionTool)
586			if !ok {
587				continue
588			}
589			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
590				OfFunction: &openai.ChatCompletionFunctionToolParam{
591					Function: shared.FunctionDefinitionParam{
592						Name:        ft.Name,
593						Description: param.NewOpt(ft.Description),
594						Parameters:  openai.FunctionParameters(ft.InputSchema),
595						Strict:      param.NewOpt(false),
596					},
597					Type: "function",
598				},
599			})
600			continue
601		}
602
603		warnings = append(warnings, fantasy.CallWarning{
604			Type:    fantasy.CallWarningTypeUnsupportedTool,
605			Tool:    tool,
606			Message: "tool is not supported",
607		})
608	}
609	if toolChoice == nil {
610		return openAiTools, openAiToolChoice, warnings
611	}
612
613	switch *toolChoice {
614	case fantasy.ToolChoiceAuto:
615		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
616			OfAuto: param.NewOpt("auto"),
617		}
618	case fantasy.ToolChoiceNone:
619		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
620			OfAuto: param.NewOpt("none"),
621		}
622	case fantasy.ToolChoiceRequired:
623		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
624			OfAuto: param.NewOpt("required"),
625		}
626	default:
627		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
628			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
629				Type: "function",
630				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
631					Name: string(*toolChoice),
632				},
633			},
634		}
635	}
636	return openAiTools, openAiToolChoice, warnings
637}
638
639// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
640func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
641	var annotations []openai.ChatCompletionMessageAnnotation
642
643	// Parse the raw JSON to extract annotations
644	var deltaData map[string]any
645	if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
646		return annotations
647	}
648
649	// Check if annotations exist in the delta
650	if annotationsData, ok := deltaData["annotations"].([]any); ok {
651		for _, annotationData := range annotationsData {
652			if annotationMap, ok := annotationData.(map[string]any); ok {
653				if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
654					if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
655						url, urlOk := urlCitationData["url"].(string)
656						title, titleOk := urlCitationData["title"].(string)
657						if urlOk && titleOk {
658							annotation := openai.ChatCompletionMessageAnnotation{
659								Type: "url_citation",
660								URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
661									URL:   url,
662									Title: title,
663								},
664							}
665							annotations = append(annotations, annotation)
666						}
667					}
668				}
669			}
670		}
671	}
672
673	return annotations
674}
675
676// GenerateObject implements fantasy.LanguageModel.
677func (o languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
678	switch o.objectMode {
679	case fantasy.ObjectModeText:
680		return object.GenerateWithText(ctx, o, call)
681	case fantasy.ObjectModeTool:
682		return object.GenerateWithTool(ctx, o, call)
683	default:
684		return o.generateObjectWithJSONMode(ctx, call)
685	}
686}
687
688// StreamObject implements fantasy.LanguageModel.
689func (o languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
690	switch o.objectMode {
691	case fantasy.ObjectModeTool:
692		return object.StreamWithTool(ctx, o, call)
693	case fantasy.ObjectModeText:
694		return object.StreamWithText(ctx, o, call)
695	default:
696		return o.streamObjectWithJSONMode(ctx, call)
697	}
698}
699
700func (o languageModel) generateObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
701	jsonSchemaMap := schema.ToMap(call.Schema)
702
703	addAdditionalPropertiesFalse(jsonSchemaMap)
704
705	schemaName := call.SchemaName
706	if schemaName == "" {
707		schemaName = "response"
708	}
709
710	fantasyCall := fantasy.Call{
711		Prompt:           call.Prompt,
712		MaxOutputTokens:  call.MaxOutputTokens,
713		Temperature:      call.Temperature,
714		TopP:             call.TopP,
715		PresencePenalty:  call.PresencePenalty,
716		FrequencyPenalty: call.FrequencyPenalty,
717		ProviderOptions:  call.ProviderOptions,
718	}
719
720	params, warnings, err := o.prepareParams(fantasyCall)
721	if err != nil {
722		return nil, err
723	}
724
725	params.ResponseFormat = openai.ChatCompletionNewParamsResponseFormatUnion{
726		OfJSONSchema: &shared.ResponseFormatJSONSchemaParam{
727			JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
728				Name:        schemaName,
729				Description: param.NewOpt(call.SchemaDescription),
730				Schema:      jsonSchemaMap,
731				Strict:      param.NewOpt(true),
732			},
733		},
734	}
735
736	response, err := o.client.Chat.Completions.New(ctx, *params)
737	if err != nil {
738		return nil, toProviderErr(err)
739	}
740
741	if len(response.Choices) == 0 {
742		usage, _ := o.usageFunc(*response)
743		return nil, &fantasy.NoObjectGeneratedError{
744			RawText:      "",
745			ParseError:   fmt.Errorf("no choices in response"),
746			Usage:        usage,
747			FinishReason: fantasy.FinishReasonUnknown,
748		}
749	}
750
751	choice := response.Choices[0]
752	jsonText := choice.Message.Content
753
754	var obj any
755	if call.RepairText != nil {
756		obj, err = schema.ParseAndValidateWithRepair(ctx, jsonText, call.Schema, call.RepairText)
757	} else {
758		obj, err = schema.ParseAndValidate(jsonText, call.Schema)
759	}
760
761	usage, _ := o.usageFunc(*response)
762	finishReason := o.mapFinishReasonFunc(choice.FinishReason)
763
764	if err != nil {
765		if nogErr, ok := err.(*fantasy.NoObjectGeneratedError); ok {
766			nogErr.Usage = usage
767			nogErr.FinishReason = finishReason
768		}
769		return nil, err
770	}
771
772	return &fantasy.ObjectResponse{
773		Object:       obj,
774		RawText:      jsonText,
775		Usage:        usage,
776		FinishReason: finishReason,
777		Warnings:     warnings,
778	}, nil
779}
780
// streamObjectWithJSONMode streams a structured object using OpenAI's strict
// JSON-schema response format. As content deltas arrive, the accumulated text
// is parsed as (possibly partial) JSON; every time a new schema-valid object
// is obtained, it is yielded as an object part. The stream ends with a finish
// part if at least one valid object was produced, otherwise with a
// NoObjectGeneratedError.
func (o languageModel) streamObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
	jsonSchemaMap := schema.ToMap(call.Schema)

	// OpenAI strict mode requires additionalProperties:false on object schemas.
	addAdditionalPropertiesFalse(jsonSchemaMap)

	schemaName := call.SchemaName
	if schemaName == "" {
		schemaName = "response"
	}

	// Re-pack the object call as a plain call so prepareParams can apply the
	// shared model-specific parameter handling.
	fantasyCall := fantasy.Call{
		Prompt:           call.Prompt,
		MaxOutputTokens:  call.MaxOutputTokens,
		Temperature:      call.Temperature,
		TopP:             call.TopP,
		PresencePenalty:  call.PresencePenalty,
		FrequencyPenalty: call.FrequencyPenalty,
		ProviderOptions:  call.ProviderOptions,
	}

	params, warnings, err := o.prepareParams(fantasyCall)
	if err != nil {
		return nil, err
	}

	params.ResponseFormat = openai.ChatCompletionNewParamsResponseFormatUnion{
		OfJSONSchema: &shared.ResponseFormatJSONSchemaParam{
			JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
				Name:        schemaName,
				Description: param.NewOpt(call.SchemaDescription),
				Schema:      jsonSchemaMap,
				Strict:      param.NewOpt(true),
			},
		},
	}

	// Request usage data on the final chunk of the stream.
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)

	return func(yield func(fantasy.ObjectStreamPart) bool) {
		// NOTE(review): warnings are emitted under the Object part type;
		// presumably ObjectStreamPart has no dedicated warnings type —
		// confirm against the fantasy API.
		if len(warnings) > 0 {
			if !yield(fantasy.ObjectStreamPart{
				Type:     fantasy.ObjectStreamPartTypeObject,
				Warnings: warnings,
			}) {
				return
			}
		}

		var accumulated string
		var lastParsedObject any
		var usage fantasy.Usage
		var finishReason fantasy.FinishReason
		var providerMetadata fantasy.ProviderMetadata
		var streamErr error

		for stream.Next() {
			chunk := stream.Current()

			// Update usage.
			// NOTE(review): a fresh context map is passed on every chunk,
			// unlike Stream which reuses one map across chunks — confirm
			// streamUsageFunc does not rely on accumulated context.
			usage, providerMetadata = o.streamUsageFunc(chunk, make(map[string]any), providerMetadata)

			if len(chunk.Choices) == 0 {
				continue
			}

			choice := chunk.Choices[0]
			if choice.FinishReason != "" {
				finishReason = o.mapFinishReasonFunc(choice.FinishReason)
			}

			if choice.Delta.Content != "" {
				accumulated += choice.Delta.Content

				// Re-parse the accumulated text; partial JSON may already
				// yield a complete (or repairable) object.
				obj, state, parseErr := schema.ParsePartialJSON(accumulated)

				if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
					if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
						// Only emit when the parsed object actually changed.
						if !reflect.DeepEqual(obj, lastParsedObject) {
							if !yield(fantasy.ObjectStreamPart{
								Type:   fantasy.ObjectStreamPartTypeObject,
								Object: obj,
							}) {
								return
							}
							lastParsedObject = obj
						}
					}
				}

				// On a hard parse failure, give the caller's repair hook a
				// chance to fix the accumulated text.
				if state == schema.ParseStateFailed && call.RepairText != nil {
					repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
					if repairErr == nil {
						obj2, state2, _ := schema.ParsePartialJSON(repairedText)
						if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
							schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
							if !reflect.DeepEqual(obj2, lastParsedObject) {
								if !yield(fantasy.ObjectStreamPart{
									Type:   fantasy.ObjectStreamPartTypeObject,
									Object: obj2,
								}) {
									return
								}
								lastParsedObject = obj2
							}
						}
					}
				}
			}
		}

		err := stream.Err()
		if err != nil && !errors.Is(err, io.EOF) {
			streamErr = toProviderErr(err)
			yield(fantasy.ObjectStreamPart{
				Type:  fantasy.ObjectStreamPartTypeError,
				Error: streamErr,
			})
			return
		}

		if lastParsedObject != nil {
			yield(fantasy.ObjectStreamPart{
				Type:             fantasy.ObjectStreamPartTypeFinish,
				Usage:            usage,
				FinishReason:     finishReason,
				ProviderMetadata: providerMetadata,
			})
		} else {
			// Stream ended without ever producing a schema-valid object.
			yield(fantasy.ObjectStreamPart{
				Type: fantasy.ObjectStreamPartTypeError,
				Error: &fantasy.NoObjectGeneratedError{
					RawText:      accumulated,
					ParseError:   fmt.Errorf("no valid object generated in stream"),
					Usage:        usage,
					FinishReason: finishReason,
				},
			})
		}
	}, nil
}
925
// addAdditionalPropertiesFalse recursively adds "additionalProperties": false
// to every object schema reachable from node, as required by OpenAI's strict
// mode for structured outputs. Existing additionalProperties values are left
// untouched. Besides "properties" and single-schema "items" (which the
// original handled), this also descends into list-form "items",
// "$defs"/"definitions", and the "anyOf"/"oneOf"/"allOf" combinators, all of
// which strict mode requires to be covered. The parameter was renamed from
// schema to node to stop shadowing the imported schema package.
func addAdditionalPropertiesFalse(node map[string]any) {
	if node == nil {
		return
	}
	if node["type"] == "object" {
		if _, hasAdditional := node["additionalProperties"]; !hasAdditional {
			node["additionalProperties"] = false
		}
	}

	// Named properties of an object schema.
	if properties, ok := node["properties"].(map[string]any); ok {
		for _, propValue := range properties {
			if propSchema, ok := propValue.(map[string]any); ok {
				addAdditionalPropertiesFalse(propSchema)
			}
		}
	}

	// Array items: either a single schema or (draft-04 style) a list.
	switch items := node["items"].(type) {
	case map[string]any:
		addAdditionalPropertiesFalse(items)
	case []any:
		for _, item := range items {
			if itemSchema, ok := item.(map[string]any); ok {
				addAdditionalPropertiesFalse(itemSchema)
			}
		}
	}

	// Shared definitions referenced via $ref.
	for _, key := range []string{"$defs", "definitions"} {
		if defs, ok := node[key].(map[string]any); ok {
			for _, defValue := range defs {
				if defSchema, ok := defValue.(map[string]any); ok {
					addAdditionalPropertiesFalse(defSchema)
				}
			}
		}
	}

	// Combinator branches.
	for _, key := range []string{"anyOf", "oneOf", "allOf"} {
		if branches, ok := node[key].([]any); ok {
			for _, branch := range branches {
				if branchSchema, ok := branch.(map[string]any); ok {
					addAdditionalPropertiesFalse(branchSchema)
				}
			}
		}
	}
}
948}