language_model.go

  1package openai
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"reflect"
 10	"strings"
 11
 12	"charm.land/fantasy"
 13	"charm.land/fantasy/object"
 14	"charm.land/fantasy/schema"
 15	xjson "github.com/charmbracelet/x/json"
 16	"github.com/google/uuid"
 17	"github.com/openai/openai-go/v2"
 18	"github.com/openai/openai-go/v2/packages/param"
 19	"github.com/openai/openai-go/v2/shared"
 20)
 21
// languageModel implements fantasy.LanguageModel on top of the OpenAI
// chat-completions API. The function-valued fields are pluggable hooks so
// OpenAI-compatible providers can customize request preparation, finish
// reason mapping, usage accounting, and prompt conversion; defaults for
// most hooks are installed by newLanguageModel.
type languageModel struct {
	provider                   string
	modelID                    string
	client                     openai.Client
	objectMode                 fantasy.ObjectMode
	prepareCallFunc            LanguageModelPrepareCallFunc
	mapFinishReasonFunc        LanguageModelMapFinishReasonFunc
	// extraContentFunc and streamExtraFunc have no defaults; both are
	// nil-checked before use.
	extraContentFunc           LanguageModelExtraContentFunc
	usageFunc                  LanguageModelUsageFunc
	streamUsageFunc            LanguageModelStreamUsageFunc
	streamExtraFunc            LanguageModelStreamExtraFunc
	streamProviderMetadataFunc LanguageModelStreamProviderMetadataFunc
	toPromptFunc               LanguageModelToPromptFunc
}
 36
// LanguageModelOption is a function that configures a languageModel.
// Options are applied in order by newLanguageModel, after defaults are set.
type LanguageModelOption = func(*languageModel)
 40// WithLanguageModelPrepareCallFunc sets the prepare call function for the language model.
 41func WithLanguageModelPrepareCallFunc(fn LanguageModelPrepareCallFunc) LanguageModelOption {
 42	return func(l *languageModel) {
 43		l.prepareCallFunc = fn
 44	}
 45}
 46
 47// WithLanguageModelMapFinishReasonFunc sets the map finish reason function for the language model.
 48func WithLanguageModelMapFinishReasonFunc(fn LanguageModelMapFinishReasonFunc) LanguageModelOption {
 49	return func(l *languageModel) {
 50		l.mapFinishReasonFunc = fn
 51	}
 52}
 53
 54// WithLanguageModelExtraContentFunc sets the extra content function for the language model.
 55func WithLanguageModelExtraContentFunc(fn LanguageModelExtraContentFunc) LanguageModelOption {
 56	return func(l *languageModel) {
 57		l.extraContentFunc = fn
 58	}
 59}
 60
 61// WithLanguageModelStreamExtraFunc sets the stream extra function for the language model.
 62func WithLanguageModelStreamExtraFunc(fn LanguageModelStreamExtraFunc) LanguageModelOption {
 63	return func(l *languageModel) {
 64		l.streamExtraFunc = fn
 65	}
 66}
 67
 68// WithLanguageModelUsageFunc sets the usage function for the language model.
 69func WithLanguageModelUsageFunc(fn LanguageModelUsageFunc) LanguageModelOption {
 70	return func(l *languageModel) {
 71		l.usageFunc = fn
 72	}
 73}
 74
 75// WithLanguageModelStreamUsageFunc sets the stream usage function for the language model.
 76func WithLanguageModelStreamUsageFunc(fn LanguageModelStreamUsageFunc) LanguageModelOption {
 77	return func(l *languageModel) {
 78		l.streamUsageFunc = fn
 79	}
 80}
 81
 82// WithLanguageModelToPromptFunc sets the to prompt function for the language model.
 83func WithLanguageModelToPromptFunc(fn LanguageModelToPromptFunc) LanguageModelOption {
 84	return func(l *languageModel) {
 85		l.toPromptFunc = fn
 86	}
 87}
 88
 89// WithLanguageModelObjectMode sets the object generation mode.
 90func WithLanguageModelObjectMode(om fantasy.ObjectMode) LanguageModelOption {
 91	return func(l *languageModel) {
 92		// not supported
 93		if om == fantasy.ObjectModeJSON {
 94			om = fantasy.ObjectModeAuto
 95		}
 96		l.objectMode = om
 97	}
 98}
 99
100func newLanguageModel(modelID string, provider string, client openai.Client, opts ...LanguageModelOption) languageModel {
101	model := languageModel{
102		modelID:                    modelID,
103		provider:                   provider,
104		client:                     client,
105		objectMode:                 fantasy.ObjectModeAuto,
106		prepareCallFunc:            DefaultPrepareCallFunc,
107		mapFinishReasonFunc:        DefaultMapFinishReasonFunc,
108		usageFunc:                  DefaultUsageFunc,
109		streamUsageFunc:            DefaultStreamUsageFunc,
110		streamProviderMetadataFunc: DefaultStreamProviderMetadataFunc,
111		toPromptFunc:               DefaultToPrompt,
112	}
113
114	for _, o := range opts {
115		o(&model)
116	}
117	return model
118}
119
// streamToolCall accumulates one tool call across streaming deltas: the
// arguments string grows chunk by chunk, and hasFinished marks that the
// completed call has already been emitted downstream so later deltas for
// the same index are ignored.
type streamToolCall struct {
	id          string
	name        string
	arguments   string
	hasFinished bool
}
126
// Model implements fantasy.LanguageModel. It returns the model ID this
// instance was constructed with.
func (o languageModel) Model() string {
	return o.modelID
}
131
// Provider implements fantasy.LanguageModel. It returns the provider name
// this instance was constructed with.
func (o languageModel) Provider() string {
	return o.provider
}
136
137func (o languageModel) prepareParams(call fantasy.Call) (*openai.ChatCompletionNewParams, []fantasy.CallWarning, error) {
138	params := &openai.ChatCompletionNewParams{}
139	messages, warnings := o.toPromptFunc(call.Prompt, o.provider, o.modelID)
140	if call.TopK != nil {
141		warnings = append(warnings, fantasy.CallWarning{
142			Type:    fantasy.CallWarningTypeUnsupportedSetting,
143			Setting: "top_k",
144		})
145	}
146
147	if call.MaxOutputTokens != nil {
148		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
149	}
150	if call.Temperature != nil {
151		params.Temperature = param.NewOpt(*call.Temperature)
152	}
153	if call.TopP != nil {
154		params.TopP = param.NewOpt(*call.TopP)
155	}
156	if call.FrequencyPenalty != nil {
157		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
158	}
159	if call.PresencePenalty != nil {
160		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
161	}
162
163	if isReasoningModel(o.modelID) {
164		// remove unsupported settings for reasoning models
165		// see https://platform.openai.com/docs/guides/reasoning#limitations
166		if call.Temperature != nil {
167			params.Temperature = param.Opt[float64]{}
168			warnings = append(warnings, fantasy.CallWarning{
169				Type:    fantasy.CallWarningTypeUnsupportedSetting,
170				Setting: "temperature",
171				Details: "temperature is not supported for reasoning models",
172			})
173		}
174		if call.TopP != nil {
175			params.TopP = param.Opt[float64]{}
176			warnings = append(warnings, fantasy.CallWarning{
177				Type:    fantasy.CallWarningTypeUnsupportedSetting,
178				Setting: "TopP",
179				Details: "TopP is not supported for reasoning models",
180			})
181		}
182		if call.FrequencyPenalty != nil {
183			params.FrequencyPenalty = param.Opt[float64]{}
184			warnings = append(warnings, fantasy.CallWarning{
185				Type:    fantasy.CallWarningTypeUnsupportedSetting,
186				Setting: "FrequencyPenalty",
187				Details: "FrequencyPenalty is not supported for reasoning models",
188			})
189		}
190		if call.PresencePenalty != nil {
191			params.PresencePenalty = param.Opt[float64]{}
192			warnings = append(warnings, fantasy.CallWarning{
193				Type:    fantasy.CallWarningTypeUnsupportedSetting,
194				Setting: "PresencePenalty",
195				Details: "PresencePenalty is not supported for reasoning models",
196			})
197		}
198
199		// reasoning models use max_completion_tokens instead of max_tokens
200		if call.MaxOutputTokens != nil {
201			if !params.MaxCompletionTokens.Valid() {
202				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
203			}
204			params.MaxTokens = param.Opt[int64]{}
205		}
206	}
207
208	// Handle search preview models
209	if isSearchPreviewModel(o.modelID) {
210		if call.Temperature != nil {
211			params.Temperature = param.Opt[float64]{}
212			warnings = append(warnings, fantasy.CallWarning{
213				Type:    fantasy.CallWarningTypeUnsupportedSetting,
214				Setting: "temperature",
215				Details: "temperature is not supported for the search preview models and has been removed.",
216			})
217		}
218	}
219
220	optionsWarnings, err := o.prepareCallFunc(o, params, call)
221	if err != nil {
222		return nil, nil, err
223	}
224
225	if len(optionsWarnings) > 0 {
226		warnings = append(warnings, optionsWarnings...)
227	}
228
229	params.Messages = messages
230	params.Model = o.modelID
231
232	if len(call.Tools) > 0 {
233		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
234		params.Tools = tools
235		if toolChoice != nil {
236			params.ToolChoice = *toolChoice
237		}
238		warnings = append(warnings, toolWarnings...)
239	}
240	return params, warnings, nil
241}
242
243// Generate implements fantasy.LanguageModel.
244func (o languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
245	params, warnings, err := o.prepareParams(call)
246	if err != nil {
247		return nil, err
248	}
249	response, err := o.client.Chat.Completions.New(ctx, *params, callUARequestOptions(call)...)
250	if err != nil {
251		return nil, toProviderErr(err)
252	}
253
254	if len(response.Choices) == 0 {
255		return nil, &fantasy.Error{Title: "no response", Message: "no response generated"}
256	}
257	choice := response.Choices[0]
258	content := make([]fantasy.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
259	text := choice.Message.Content
260	if text != "" {
261		content = append(content, fantasy.TextContent{
262			Text: text,
263		})
264	}
265	if o.extraContentFunc != nil {
266		extraContent := o.extraContentFunc(choice)
267		content = append(content, extraContent...)
268	}
269	for _, tc := range choice.Message.ToolCalls {
270		toolCallID := tc.ID
271		content = append(content, fantasy.ToolCallContent{
272			ProviderExecuted: false,
273			ToolCallID:       toolCallID,
274			ToolName:         tc.Function.Name,
275			Input:            tc.Function.Arguments,
276		})
277	}
278	for _, annotation := range choice.Message.Annotations {
279		if annotation.Type == "url_citation" {
280			content = append(content, fantasy.SourceContent{
281				SourceType: fantasy.SourceTypeURL,
282				ID:         uuid.NewString(),
283				URL:        annotation.URLCitation.URL,
284				Title:      annotation.URLCitation.Title,
285			})
286		}
287	}
288
289	usage, providerMetadata := o.usageFunc(*response)
290
291	mappedFinishReason := o.mapFinishReasonFunc(choice.FinishReason)
292	if len(choice.Message.ToolCalls) > 0 {
293		mappedFinishReason = fantasy.FinishReasonToolCalls
294	}
295	return &fantasy.Response{
296		Content:      content,
297		Usage:        usage,
298		FinishReason: mappedFinishReason,
299		ProviderMetadata: fantasy.ProviderMetadata{
300			Name: providerMetadata,
301		},
302		Warnings: warnings,
303	}, nil
304}
305
306// Stream implements fantasy.LanguageModel.
307func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
308	params, warnings, err := o.prepareParams(call)
309	if err != nil {
310		return nil, err
311	}
312
313	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
314		IncludeUsage: openai.Bool(true),
315	}
316
317	stream := o.client.Chat.Completions.NewStreaming(ctx, *params, callUARequestOptions(call)...)
318	isActiveText := false
319	toolCalls := make(map[int64]streamToolCall)
320
321	providerMetadata := fantasy.ProviderMetadata{
322		Name: &ProviderMetadata{},
323	}
324	acc := openai.ChatCompletionAccumulator{}
325	extraContext := make(map[string]any)
326	var usage fantasy.Usage
327	var finishReason string
328	return func(yield func(fantasy.StreamPart) bool) {
329		if len(warnings) > 0 {
330			if !yield(fantasy.StreamPart{
331				Type:     fantasy.StreamPartTypeWarnings,
332				Warnings: warnings,
333			}) {
334				return
335			}
336		}
337		for stream.Next() {
338			chunk := stream.Current()
339			acc.AddChunk(chunk)
340			usage, providerMetadata = o.streamUsageFunc(chunk, extraContext, providerMetadata)
341			if len(chunk.Choices) == 0 {
342				continue
343			}
344			for _, choice := range chunk.Choices {
345				if choice.FinishReason != "" {
346					finishReason = choice.FinishReason
347				}
348				switch {
349				case choice.Delta.Content != "":
350					if !isActiveText {
351						isActiveText = true
352						if !yield(fantasy.StreamPart{
353							Type: fantasy.StreamPartTypeTextStart,
354							ID:   "0",
355						}) {
356							return
357						}
358					}
359					if !yield(fantasy.StreamPart{
360						Type:  fantasy.StreamPartTypeTextDelta,
361						ID:    "0",
362						Delta: choice.Delta.Content,
363					}) {
364						return
365					}
366				case len(choice.Delta.ToolCalls) > 0:
367					if isActiveText {
368						isActiveText = false
369						if !yield(fantasy.StreamPart{
370							Type: fantasy.StreamPartTypeTextEnd,
371							ID:   "0",
372						}) {
373							return
374						}
375					}
376
377					for _, toolCallDelta := range choice.Delta.ToolCalls {
378						if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
379							if existingToolCall.hasFinished {
380								continue
381							}
382							if toolCallDelta.Function.Arguments != "" {
383								existingToolCall.arguments += toolCallDelta.Function.Arguments
384							}
385							if !yield(fantasy.StreamPart{
386								Type:  fantasy.StreamPartTypeToolInputDelta,
387								ID:    existingToolCall.id,
388								Delta: toolCallDelta.Function.Arguments,
389							}) {
390								return
391							}
392							toolCalls[toolCallDelta.Index] = existingToolCall
393							if xjson.IsValid(existingToolCall.arguments) {
394								if !yield(fantasy.StreamPart{
395									Type: fantasy.StreamPartTypeToolInputEnd,
396									ID:   existingToolCall.id,
397								}) {
398									return
399								}
400
401								if !yield(fantasy.StreamPart{
402									Type:          fantasy.StreamPartTypeToolCall,
403									ID:            existingToolCall.id,
404									ToolCallName:  existingToolCall.name,
405									ToolCallInput: existingToolCall.arguments,
406								}) {
407									return
408								}
409								existingToolCall.hasFinished = true
410								toolCalls[toolCallDelta.Index] = existingToolCall
411							}
412						} else {
413							var err error
414							if toolCallDelta.Type != "function" {
415								err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function' type."}
416							}
417							if toolCallDelta.ID == "" {
418								err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'id' to be a string."}
419							}
420							if toolCallDelta.Function.Name == "" {
421								err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function.name' to be a string."}
422							}
423							if err != nil {
424								yield(fantasy.StreamPart{
425									Type:  fantasy.StreamPartTypeError,
426									Error: toProviderErr(stream.Err()),
427								})
428								return
429							}
430
431							if !yield(fantasy.StreamPart{
432								Type:         fantasy.StreamPartTypeToolInputStart,
433								ID:           toolCallDelta.ID,
434								ToolCallName: toolCallDelta.Function.Name,
435							}) {
436								return
437							}
438							toolCalls[toolCallDelta.Index] = streamToolCall{
439								id:        toolCallDelta.ID,
440								name:      toolCallDelta.Function.Name,
441								arguments: toolCallDelta.Function.Arguments,
442							}
443
444							exTc := toolCalls[toolCallDelta.Index]
445							if exTc.arguments != "" {
446								if !yield(fantasy.StreamPart{
447									Type:  fantasy.StreamPartTypeToolInputDelta,
448									ID:    exTc.id,
449									Delta: exTc.arguments,
450								}) {
451									return
452								}
453								if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
454									if !yield(fantasy.StreamPart{
455										Type: fantasy.StreamPartTypeToolInputEnd,
456										ID:   toolCallDelta.ID,
457									}) {
458										return
459									}
460
461									if !yield(fantasy.StreamPart{
462										Type:          fantasy.StreamPartTypeToolCall,
463										ID:            exTc.id,
464										ToolCallName:  exTc.name,
465										ToolCallInput: exTc.arguments,
466									}) {
467										return
468									}
469									exTc.hasFinished = true
470									toolCalls[toolCallDelta.Index] = exTc
471								}
472							}
473							continue
474						}
475					}
476				}
477
478				if o.streamExtraFunc != nil {
479					updatedContext, shouldContinue := o.streamExtraFunc(chunk, yield, extraContext)
480					if !shouldContinue {
481						return
482					}
483					extraContext = updatedContext
484				}
485			}
486
487			for _, choice := range chunk.Choices {
488				if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
489					for _, annotation := range annotations {
490						if annotation.Type == "url_citation" {
491							if !yield(fantasy.StreamPart{
492								Type:       fantasy.StreamPartTypeSource,
493								ID:         uuid.NewString(),
494								SourceType: fantasy.SourceTypeURL,
495								URL:        annotation.URLCitation.URL,
496								Title:      annotation.URLCitation.Title,
497							}) {
498								return
499							}
500						}
501					}
502				}
503			}
504		}
505		err := stream.Err()
506		if err == nil || errors.Is(err, io.EOF) {
507			if isActiveText {
508				isActiveText = false
509				if !yield(fantasy.StreamPart{
510					Type: fantasy.StreamPartTypeTextEnd,
511					ID:   "0",
512				}) {
513					return
514				}
515			}
516
517			if len(acc.Choices) > 0 {
518				choice := acc.Choices[0]
519				providerMetadata = o.streamProviderMetadataFunc(choice, providerMetadata)
520
521				for _, annotation := range choice.Message.Annotations {
522					if annotation.Type == "url_citation" {
523						if !yield(fantasy.StreamPart{
524							Type:       fantasy.StreamPartTypeSource,
525							ID:         acc.ID,
526							SourceType: fantasy.SourceTypeURL,
527							URL:        annotation.URLCitation.URL,
528							Title:      annotation.URLCitation.Title,
529						}) {
530							return
531						}
532					}
533				}
534			}
535			mappedFinishReason := o.mapFinishReasonFunc(finishReason)
536			if len(acc.Choices) > 0 {
537				choice := acc.Choices[0]
538				if len(choice.Message.ToolCalls) > 0 {
539					mappedFinishReason = fantasy.FinishReasonToolCalls
540				}
541			}
542			yield(fantasy.StreamPart{
543				Type:             fantasy.StreamPartTypeFinish,
544				Usage:            usage,
545				FinishReason:     mappedFinishReason,
546				ProviderMetadata: providerMetadata,
547			})
548			return
549		} else { //nolint: revive
550			yield(fantasy.StreamPart{
551				Type:  fantasy.StreamPartTypeError,
552				Error: toProviderErr(err),
553			})
554			return
555		}
556	}, nil
557}
558
// isReasoningModel reports whether modelID names an OpenAI reasoning model
// (o1/o3/o4 families, gpt-oss, and the gpt-5 family), which reject several
// sampling settings and use max_completion_tokens.
//
// The previous `|| strings.Contains(modelID, "gpt-5-chat")` clause was dead
// code: any ID containing "gpt-5-chat" already contains "gpt-5".
func isReasoningModel(modelID string) bool {
	return strings.HasPrefix(modelID, "o1") || strings.Contains(modelID, "-o1") ||
		strings.HasPrefix(modelID, "o3") || strings.Contains(modelID, "-o3") ||
		strings.HasPrefix(modelID, "o4") || strings.Contains(modelID, "-o4") ||
		strings.HasPrefix(modelID, "oss") || strings.Contains(modelID, "-oss") ||
		strings.Contains(modelID, "gpt-5")
}
566
// isSearchPreviewModel reports whether modelID refers to one of the
// web-search preview chat models (which reject the temperature setting).
func isSearchPreviewModel(modelID string) bool {
	const marker = "search-preview"
	return strings.Contains(modelID, marker)
}
570
// supportsFlexProcessing reports whether modelID accepts the "flex"
// service tier (o3 family, o4-mini, and the gpt-5 family).
func supportsFlexProcessing(modelID string) bool {
	switch {
	case strings.HasPrefix(modelID, "o3"), strings.Contains(modelID, "-o3"):
		return true
	case strings.Contains(modelID, "o4-mini"), strings.Contains(modelID, "gpt-5"):
		return true
	default:
		return false
	}
}
575
// supportsPriorityProcessing reports whether modelID accepts the
// "priority" service tier (gpt-4/gpt-5 families, o3 family, o4-mini).
//
// The previous `strings.Contains(modelID, "gpt-5-mini")` clause was dead
// code: any ID containing "gpt-5-mini" already contains "gpt-5".
func supportsPriorityProcessing(modelID string) bool {
	return strings.Contains(modelID, "gpt-4") || strings.Contains(modelID, "gpt-5") ||
		strings.HasPrefix(modelID, "o3") || strings.Contains(modelID, "-o3") ||
		strings.Contains(modelID, "o4-mini")
}
581
582func toOpenAiTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []fantasy.CallWarning) {
583	for _, tool := range tools {
584		if tool.GetType() == fantasy.ToolTypeFunction {
585			ft, ok := tool.(fantasy.FunctionTool)
586			if !ok {
587				continue
588			}
589			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
590				OfFunction: &openai.ChatCompletionFunctionToolParam{
591					Function: shared.FunctionDefinitionParam{
592						Name:        ft.Name,
593						Description: param.NewOpt(ft.Description),
594						Parameters:  openai.FunctionParameters(ft.InputSchema),
595						Strict:      param.NewOpt(false),
596					},
597					Type: "function",
598				},
599			})
600			continue
601		}
602
603		warnings = append(warnings, fantasy.CallWarning{
604			Type:    fantasy.CallWarningTypeUnsupportedTool,
605			Tool:    tool,
606			Message: "tool is not supported",
607		})
608	}
609	if toolChoice == nil {
610		return openAiTools, openAiToolChoice, warnings
611	}
612
613	switch *toolChoice {
614	case fantasy.ToolChoiceAuto:
615		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
616			OfAuto: param.NewOpt("auto"),
617		}
618	case fantasy.ToolChoiceNone:
619		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
620			OfAuto: param.NewOpt("none"),
621		}
622	case fantasy.ToolChoiceRequired:
623		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
624			OfAuto: param.NewOpt("required"),
625		}
626	default:
627		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
628			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
629				Type: "function",
630				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
631					Name: string(*toolChoice),
632				},
633			},
634		}
635	}
636	return openAiTools, openAiToolChoice, warnings
637}
638
639// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
640func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
641	var annotations []openai.ChatCompletionMessageAnnotation
642
643	// Parse the raw JSON to extract annotations
644	var deltaData map[string]any
645	if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
646		return annotations
647	}
648
649	// Check if annotations exist in the delta
650	if annotationsData, ok := deltaData["annotations"].([]any); ok {
651		for _, annotationData := range annotationsData {
652			if annotationMap, ok := annotationData.(map[string]any); ok {
653				if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
654					if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
655						url, urlOk := urlCitationData["url"].(string)
656						title, titleOk := urlCitationData["title"].(string)
657						if urlOk && titleOk {
658							annotation := openai.ChatCompletionMessageAnnotation{
659								Type: "url_citation",
660								URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
661									URL:   url,
662									Title: title,
663								},
664							}
665							annotations = append(annotations, annotation)
666						}
667					}
668				}
669			}
670		}
671	}
672
673	return annotations
674}
675
676// GenerateObject implements fantasy.LanguageModel.
677func (o languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
678	switch o.objectMode {
679	case fantasy.ObjectModeText:
680		return object.GenerateWithText(ctx, o, call)
681	case fantasy.ObjectModeTool:
682		return object.GenerateWithTool(ctx, o, call)
683	default:
684		return o.generateObjectWithJSONMode(ctx, call)
685	}
686}
687
688// StreamObject implements fantasy.LanguageModel.
689func (o languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
690	switch o.objectMode {
691	case fantasy.ObjectModeTool:
692		return object.StreamWithTool(ctx, o, call)
693	case fantasy.ObjectModeText:
694		return object.StreamWithText(ctx, o, call)
695	default:
696		return o.streamObjectWithJSONMode(ctx, call)
697	}
698}
699
700func (o languageModel) generateObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
701	jsonSchemaMap := schema.ToMap(call.Schema)
702
703	addAdditionalPropertiesFalse(jsonSchemaMap)
704
705	schemaName := call.SchemaName
706	if schemaName == "" {
707		schemaName = "response"
708	}
709
710	fantasyCall := fantasy.Call{
711		Prompt:           call.Prompt,
712		MaxOutputTokens:  call.MaxOutputTokens,
713		Temperature:      call.Temperature,
714		TopP:             call.TopP,
715		PresencePenalty:  call.PresencePenalty,
716		FrequencyPenalty: call.FrequencyPenalty,
717		ProviderOptions:  call.ProviderOptions,
718	}
719
720	params, warnings, err := o.prepareParams(fantasyCall)
721	if err != nil {
722		return nil, err
723	}
724
725	params.ResponseFormat = openai.ChatCompletionNewParamsResponseFormatUnion{
726		OfJSONSchema: &shared.ResponseFormatJSONSchemaParam{
727			JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
728				Name:        schemaName,
729				Description: param.NewOpt(call.SchemaDescription),
730				Schema:      jsonSchemaMap,
731				Strict:      param.NewOpt(true),
732			},
733		},
734	}
735
736	response, err := o.client.Chat.Completions.New(ctx, *params, objectCallUARequestOptions(call)...)
737	if err != nil {
738		return nil, toProviderErr(err)
739	}
740	if len(response.Choices) == 0 {
741		usage, _ := o.usageFunc(*response)
742		return nil, &fantasy.NoObjectGeneratedError{
743			RawText:      "",
744			ParseError:   fmt.Errorf("no choices in response"),
745			Usage:        usage,
746			FinishReason: fantasy.FinishReasonUnknown,
747		}
748	}
749
750	choice := response.Choices[0]
751	jsonText := choice.Message.Content
752
753	var obj any
754	if call.RepairText != nil {
755		obj, err = schema.ParseAndValidateWithRepair(ctx, jsonText, call.Schema, call.RepairText)
756	} else {
757		obj, err = schema.ParseAndValidate(jsonText, call.Schema)
758	}
759
760	usage, _ := o.usageFunc(*response)
761	finishReason := o.mapFinishReasonFunc(choice.FinishReason)
762
763	if err != nil {
764		if nogErr, ok := err.(*fantasy.NoObjectGeneratedError); ok {
765			nogErr.Usage = usage
766			nogErr.FinishReason = finishReason
767		}
768		return nil, err
769	}
770
771	return &fantasy.ObjectResponse{
772		Object:       obj,
773		RawText:      jsonText,
774		Usage:        usage,
775		FinishReason: finishReason,
776		Warnings:     warnings,
777	}, nil
778}
779
// streamObjectWithJSONMode streams a structured object using OpenAI's
// strict JSON-schema response format. As content deltas arrive, the
// accumulated text is partially parsed; every time a new valid object
// emerges it is yielded as an object part, and a finish (or error) part is
// emitted at the end of the stream.
func (o languageModel) streamObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
	jsonSchemaMap := schema.ToMap(call.Schema)

	// Strict structured outputs require additionalProperties: false on
	// every object level of the schema.
	addAdditionalPropertiesFalse(jsonSchemaMap)

	schemaName := call.SchemaName
	if schemaName == "" {
		schemaName = "response"
	}

	// Re-pack the object call as a plain call so prepareParams applies the
	// usual sampling settings and provider options.
	fantasyCall := fantasy.Call{
		Prompt:           call.Prompt,
		MaxOutputTokens:  call.MaxOutputTokens,
		Temperature:      call.Temperature,
		TopP:             call.TopP,
		PresencePenalty:  call.PresencePenalty,
		FrequencyPenalty: call.FrequencyPenalty,
		ProviderOptions:  call.ProviderOptions,
	}

	params, warnings, err := o.prepareParams(fantasyCall)
	if err != nil {
		return nil, err
	}

	params.ResponseFormat = openai.ChatCompletionNewParamsResponseFormatUnion{
		OfJSONSchema: &shared.ResponseFormatJSONSchemaParam{
			JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
				Name:        schemaName,
				Description: param.NewOpt(call.SchemaDescription),
				Schema:      jsonSchemaMap,
				Strict:      param.NewOpt(true),
			},
		},
	}

	// Ask OpenAI to append a final usage chunk to the stream.
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params, objectCallUARequestOptions(call)...)

	return func(yield func(fantasy.ObjectStreamPart) bool) {
		// NOTE(review): warnings are emitted with the Object part type;
		// confirm whether a dedicated warnings part type exists for object
		// streams (text streams use StreamPartTypeWarnings).
		if len(warnings) > 0 {
			if !yield(fantasy.ObjectStreamPart{
				Type:     fantasy.ObjectStreamPartTypeObject,
				Warnings: warnings,
			}) {
				return
			}
		}

		var accumulated string
		// lastParsedObject suppresses duplicate yields when a new delta
		// does not change the parsed value.
		var lastParsedObject any
		var usage fantasy.Usage
		var finishReason fantasy.FinishReason
		var providerMetadata fantasy.ProviderMetadata
		var streamErr error

		for stream.Next() {
			chunk := stream.Current()

			// Update usage
			usage, providerMetadata = o.streamUsageFunc(chunk, make(map[string]any), providerMetadata)

			if len(chunk.Choices) == 0 {
				continue
			}

			choice := chunk.Choices[0]
			if choice.FinishReason != "" {
				finishReason = o.mapFinishReasonFunc(choice.FinishReason)
			}

			if choice.Delta.Content != "" {
				accumulated += choice.Delta.Content

				obj, state, parseErr := schema.ParsePartialJSON(accumulated)

				// Yield only objects that validate against the schema and
				// differ from the last one emitted.
				if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
					if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
						if !reflect.DeepEqual(obj, lastParsedObject) {
							if !yield(fantasy.ObjectStreamPart{
								Type:   fantasy.ObjectStreamPartTypeObject,
								Object: obj,
							}) {
								return
							}
							lastParsedObject = obj
						}
					}
				}

				// On a hard parse failure, give the caller's repair hook a
				// chance to fix the accumulated text.
				if state == schema.ParseStateFailed && call.RepairText != nil {
					repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
					if repairErr == nil {
						obj2, state2, _ := schema.ParsePartialJSON(repairedText)
						if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
							schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
							if !reflect.DeepEqual(obj2, lastParsedObject) {
								if !yield(fantasy.ObjectStreamPart{
									Type:   fantasy.ObjectStreamPartTypeObject,
									Object: obj2,
								}) {
									return
								}
								lastParsedObject = obj2
							}
						}
					}
				}
			}
		}

		err := stream.Err()
		if err != nil && !errors.Is(err, io.EOF) {
			streamErr = toProviderErr(err)
			yield(fantasy.ObjectStreamPart{
				Type:  fantasy.ObjectStreamPartTypeError,
				Error: streamErr,
			})
			return
		}

		// Finish only if at least one valid object was produced; otherwise
		// report a NoObjectGeneratedError with the raw accumulated text.
		if lastParsedObject != nil {
			yield(fantasy.ObjectStreamPart{
				Type:             fantasy.ObjectStreamPartTypeFinish,
				Usage:            usage,
				FinishReason:     finishReason,
				ProviderMetadata: providerMetadata,
			})
		} else {
			yield(fantasy.ObjectStreamPart{
				Type: fantasy.ObjectStreamPartTypeError,
				Error: &fantasy.NoObjectGeneratedError{
					RawText:      accumulated,
					ParseError:   fmt.Errorf("no valid object generated in stream"),
					Usage:        usage,
					FinishReason: finishReason,
				},
			})
		}
	}, nil
}
924
// addAdditionalPropertiesFalse recursively adds "additionalProperties": false
// to all object schemas, as required by OpenAI's strict mode for structured
// outputs. An existing additionalProperties value is left untouched.
//
// Generalized to also descend into tuple-form "items" arrays, the
// anyOf/allOf/oneOf combinators, and "$defs"/"definitions" maps, since
// strict mode applies to every nested object schema, not only ones reached
// through "properties" and a single "items" map.
func addAdditionalPropertiesFalse(schema map[string]any) {
	if schema["type"] == "object" {
		if _, hasAdditional := schema["additionalProperties"]; !hasAdditional {
			schema["additionalProperties"] = false
		}

		// Recursively process nested properties.
		if properties, ok := schema["properties"].(map[string]any); ok {
			for _, propValue := range properties {
				if propSchema, ok := propValue.(map[string]any); ok {
					addAdditionalPropertiesFalse(propSchema)
				}
			}
		}
	}

	// Handle array items: single-schema form and tuple form.
	if items, ok := schema["items"].(map[string]any); ok {
		addAdditionalPropertiesFalse(items)
	}
	if items, ok := schema["items"].([]any); ok {
		for _, item := range items {
			if itemSchema, ok := item.(map[string]any); ok {
				addAdditionalPropertiesFalse(itemSchema)
			}
		}
	}

	// Handle combinators.
	for _, key := range []string{"anyOf", "allOf", "oneOf"} {
		if subs, ok := schema[key].([]any); ok {
			for _, sub := range subs {
				if subSchema, ok := sub.(map[string]any); ok {
					addAdditionalPropertiesFalse(subSchema)
				}
			}
		}
	}

	// Handle referenced definitions.
	for _, key := range []string{"$defs", "definitions"} {
		if defs, ok := schema[key].(map[string]any); ok {
			for _, def := range defs {
				if defSchema, ok := def.(map[string]any); ok {
					addAdditionalPropertiesFalse(defSchema)
				}
			}
		}
	}
}