language_model.go

package openai

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"reflect"
	"strings"

	"charm.land/fantasy"
	"charm.land/fantasy/object"
	"charm.land/fantasy/schema"
	xjson "github.com/charmbracelet/x/json"
	"github.com/google/uuid"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/openai/openai-go/v2/shared"
)

type languageModel struct {
	provider                   string
	modelID                    string
	client                     openai.Client
	objectMode                 fantasy.ObjectMode
	prepareCallFunc            LanguageModelPrepareCallFunc
	mapFinishReasonFunc        LanguageModelMapFinishReasonFunc
	extraContentFunc           LanguageModelExtraContentFunc
	usageFunc                  LanguageModelUsageFunc
	streamUsageFunc            LanguageModelStreamUsageFunc
	streamExtraFunc            LanguageModelStreamExtraFunc
	streamProviderMetadataFunc LanguageModelStreamProviderMetadataFunc
	toPromptFunc               LanguageModelToPromptFunc
}

// LanguageModelOption is a function that configures a languageModel.
type LanguageModelOption = func(*languageModel)

// WithLanguageModelPrepareCallFunc sets the prepare call function for the language model.
func WithLanguageModelPrepareCallFunc(fn LanguageModelPrepareCallFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.prepareCallFunc = fn
	}
}

// WithLanguageModelMapFinishReasonFunc sets the map finish reason function for the language model.
func WithLanguageModelMapFinishReasonFunc(fn LanguageModelMapFinishReasonFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.mapFinishReasonFunc = fn
	}
}

// WithLanguageModelExtraContentFunc sets the extra content function for the language model.
func WithLanguageModelExtraContentFunc(fn LanguageModelExtraContentFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.extraContentFunc = fn
	}
}

// WithLanguageModelStreamExtraFunc sets the stream extra function for the language model.
func WithLanguageModelStreamExtraFunc(fn LanguageModelStreamExtraFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.streamExtraFunc = fn
	}
}

// WithLanguageModelUsageFunc sets the usage function for the language model.
func WithLanguageModelUsageFunc(fn LanguageModelUsageFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.usageFunc = fn
	}
}

// WithLanguageModelStreamUsageFunc sets the stream usage function for the language model.
func WithLanguageModelStreamUsageFunc(fn LanguageModelStreamUsageFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.streamUsageFunc = fn
	}
}

// WithLanguageModelToPromptFunc sets the function that converts a fantasy prompt into OpenAI chat messages.
func WithLanguageModelToPromptFunc(fn LanguageModelToPromptFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.toPromptFunc = fn
	}
}

// WithLanguageModelObjectMode sets the object generation mode.
func WithLanguageModelObjectMode(om fantasy.ObjectMode) LanguageModelOption {
	return func(l *languageModel) {
		// ObjectModeJSON is not supported by this provider; fall back to automatic mode selection.
		if om == fantasy.ObjectModeJSON {
			om = fantasy.ObjectModeAuto
		}
		l.objectMode = om
	}
}

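// newLanguageModel constructs a languageModel with the default hook implementations
// and automatic object mode, then applies the given options.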
func newLanguageModel(modelID string, provider string, client openai.Client, opts ...LanguageModelOption) languageModel {
	model := languageModel{
		modelID:                    modelID,
		provider:                   provider,
		client:                     client,
		objectMode:                 fantasy.ObjectModeAuto,
		prepareCallFunc:            DefaultPrepareCallFunc,
		mapFinishReasonFunc:        DefaultMapFinishReasonFunc,
		usageFunc:                  DefaultUsageFunc,
		streamUsageFunc:            DefaultStreamUsageFunc,
		streamProviderMetadataFunc: DefaultStreamProviderMetadataFunc,
		toPromptFunc:               DefaultToPrompt,
	}

	for _, o := range opts {
		o(&model)
	}
	return model
}

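// streamToolCall accumulates a tool call whose ID, name, and arguments arrive
// incrementally across stream chunks.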
type streamToolCall struct {
	id          string
	name        string
	arguments   string
	hasFinished bool
}

// Model implements fantasy.LanguageModel.
func (o languageModel) Model() string {
	return o.modelID
}

// Provider implements fantasy.LanguageModel.
func (o languageModel) Provider() string {
	return o.provider
}

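// prepareParams translates a fantasy.Call into OpenAI chat completion parameters,
// dropping settings the target model does not support and collecting the
// corresponding warnings.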
func (o languageModel) prepareParams(call fantasy.Call) (*openai.ChatCompletionNewParams, []fantasy.CallWarning, error) {
	params := &openai.ChatCompletionNewParams{}
	messages, warnings := o.toPromptFunc(call.Prompt, o.provider, o.modelID)
	if call.TopK != nil {
		warnings = append(warnings, fantasy.CallWarning{
			Type:    fantasy.CallWarningTypeUnsupportedSetting,
			Setting: "top_k",
		})
	}

	if call.MaxOutputTokens != nil {
		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
	}
	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}
	if call.FrequencyPenalty != nil {
		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
	}
	if call.PresencePenalty != nil {
		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
	}

	if isReasoningModel(o.modelID) {
		// remove unsupported settings for reasoning models
		// see https://platform.openai.com/docs/guides/reasoning#limitations
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for reasoning models",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "TopP",
				Details: "TopP is not supported for reasoning models",
			})
		}
		if call.FrequencyPenalty != nil {
			params.FrequencyPenalty = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "FrequencyPenalty",
				Details: "FrequencyPenalty is not supported for reasoning models",
			})
		}
		if call.PresencePenalty != nil {
			params.PresencePenalty = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "PresencePenalty",
				Details: "PresencePenalty is not supported for reasoning models",
			})
		}

		// reasoning models use max_completion_tokens instead of max_tokens
		if call.MaxOutputTokens != nil {
			if !params.MaxCompletionTokens.Valid() {
				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
			}
			params.MaxTokens = param.Opt[int64]{}
		}
	}

	// Handle search preview models
	if isSearchPreviewModel(o.modelID) {
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for the search preview models and has been removed.",
			})
		}
	}

	optionsWarnings, err := o.prepareCallFunc(o, params, call)
	if err != nil {
		return nil, nil, err
	}

	if len(optionsWarnings) > 0 {
		warnings = append(warnings, optionsWarnings...)
	}

	params.Messages = messages
	params.Model = o.modelID

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}
	return params, warnings, nil
}

// Generate implements fantasy.LanguageModel.
func (o languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, toProviderErr(err)
	}

	if len(response.Choices) == 0 {
		return nil, &fantasy.Error{Title: "no response", Message: "no response generated"}
	}
	choice := response.Choices[0]
	content := make([]fantasy.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
	text := choice.Message.Content
	if text != "" {
		content = append(content, fantasy.TextContent{
			Text: text,
		})
	}
	if o.extraContentFunc != nil {
		extraContent := o.extraContentFunc(choice)
		content = append(content, extraContent...)
	}
	for _, tc := range choice.Message.ToolCalls {
		toolCallID := tc.ID
		content = append(content, fantasy.ToolCallContent{
			ProviderExecuted: false,
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, fantasy.SourceContent{
				SourceType: fantasy.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	usage, providerMetadata := o.usageFunc(*response)

	mappedFinishReason := o.mapFinishReasonFunc(choice.FinishReason)
	if len(choice.Message.ToolCalls) > 0 {
		mappedFinishReason = fantasy.FinishReasonToolCalls
	}
	return &fantasy.Response{
		Content:      content,
		Usage:        usage,
		FinishReason: mappedFinishReason,
		ProviderMetadata: fantasy.ProviderMetadata{
			Name: providerMetadata,
		},
		Warnings: warnings,
	}, nil
}

// Stream implements fantasy.LanguageModel.
func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}

	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
	isActiveText := false
	toolCalls := make(map[int64]streamToolCall)

	providerMetadata := fantasy.ProviderMetadata{
		Name: &ProviderMetadata{},
	}
	acc := openai.ChatCompletionAccumulator{}
	extraContext := make(map[string]any)
	var usage fantasy.Usage
	var finishReason string
	return func(yield func(fantasy.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(fantasy.StreamPart{
				Type:     fantasy.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}
		for stream.Next() {
			chunk := stream.Current()
			acc.AddChunk(chunk)
			usage, providerMetadata = o.streamUsageFunc(chunk, extraContext, providerMetadata)
			if len(chunk.Choices) == 0 {
				continue
			}
			for _, choice := range chunk.Choices {
				if choice.FinishReason != "" {
					finishReason = choice.FinishReason
				}
				switch {
				case choice.Delta.Content != "":
					if !isActiveText {
						isActiveText = true
						if !yield(fantasy.StreamPart{
							Type: fantasy.StreamPartTypeTextStart,
							ID:   "0",
						}) {
							return
						}
					}
					if !yield(fantasy.StreamPart{
						Type:  fantasy.StreamPartTypeTextDelta,
						ID:    "0",
						Delta: choice.Delta.Content,
					}) {
						return
					}
				case len(choice.Delta.ToolCalls) > 0:
					if isActiveText {
						isActiveText = false
						if !yield(fantasy.StreamPart{
							Type: fantasy.StreamPartTypeTextEnd,
							ID:   "0",
						}) {
							return
						}
					}

					for _, toolCallDelta := range choice.Delta.ToolCalls {
						if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
							if existingToolCall.hasFinished {
								continue
							}
							if toolCallDelta.Function.Arguments != "" {
								existingToolCall.arguments += toolCallDelta.Function.Arguments
							}
							if !yield(fantasy.StreamPart{
								Type:  fantasy.StreamPartTypeToolInputDelta,
								ID:    existingToolCall.id,
								Delta: toolCallDelta.Function.Arguments,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = existingToolCall
							if xjson.IsValid(existingToolCall.arguments) {
								if !yield(fantasy.StreamPart{
									Type: fantasy.StreamPartTypeToolInputEnd,
									ID:   existingToolCall.id,
								}) {
									return
								}

								if !yield(fantasy.StreamPart{
									Type:          fantasy.StreamPartTypeToolCall,
									ID:            existingToolCall.id,
									ToolCallName:  existingToolCall.name,
									ToolCallInput: existingToolCall.arguments,
								}) {
									return
								}
								existingToolCall.hasFinished = true
								toolCalls[toolCallDelta.Index] = existingToolCall
							}
						} else {
							// A new tool call delta must carry a function type, an ID, and a function name.
							var err error
							if toolCallDelta.Type != "function" {
								err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function' type."}
							}
							if toolCallDelta.ID == "" {
								err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'id' to be a string."}
							}
							if toolCallDelta.Function.Name == "" {
								err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function.name' to be a string."}
							}
							if err != nil {
								yield(fantasy.StreamPart{
									Type:  fantasy.StreamPartTypeError,
									Error: err,
								})
								return
							}

							if !yield(fantasy.StreamPart{
								Type:         fantasy.StreamPartTypeToolInputStart,
								ID:           toolCallDelta.ID,
								ToolCallName: toolCallDelta.Function.Name,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = streamToolCall{
								id:        toolCallDelta.ID,
								name:      toolCallDelta.Function.Name,
								arguments: toolCallDelta.Function.Arguments,
							}

							exTc := toolCalls[toolCallDelta.Index]
							if exTc.arguments != "" {
								if !yield(fantasy.StreamPart{
									Type:  fantasy.StreamPartTypeToolInputDelta,
									ID:    exTc.id,
									Delta: exTc.arguments,
								}) {
									return
								}
								if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
									if !yield(fantasy.StreamPart{
										Type: fantasy.StreamPartTypeToolInputEnd,
										ID:   toolCallDelta.ID,
									}) {
										return
									}

									if !yield(fantasy.StreamPart{
										Type:          fantasy.StreamPartTypeToolCall,
										ID:            exTc.id,
										ToolCallName:  exTc.name,
										ToolCallInput: exTc.arguments,
									}) {
										return
									}
									exTc.hasFinished = true
									toolCalls[toolCallDelta.Index] = exTc
								}
							}
							continue
						}
					}
				}

				if o.streamExtraFunc != nil {
					updatedContext, shouldContinue := o.streamExtraFunc(chunk, yield, extraContext)
					if !shouldContinue {
						return
					}
					extraContext = updatedContext
				}
			}

			for _, choice := range chunk.Choices {
				if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
					for _, annotation := range annotations {
						if annotation.Type == "url_citation" {
							if !yield(fantasy.StreamPart{
								Type:       fantasy.StreamPartTypeSource,
								ID:         uuid.NewString(),
								SourceType: fantasy.SourceTypeURL,
								URL:        annotation.URLCitation.URL,
								Title:      annotation.URLCitation.Title,
							}) {
								return
							}
						}
					}
				}
			}
		}
		err := stream.Err()
		if err == nil || errors.Is(err, io.EOF) {
			if isActiveText {
				isActiveText = false
				if !yield(fantasy.StreamPart{
					Type: fantasy.StreamPartTypeTextEnd,
					ID:   "0",
				}) {
					return
				}
			}

			if len(acc.Choices) > 0 {
				choice := acc.Choices[0]
				providerMetadata = o.streamProviderMetadataFunc(choice, providerMetadata)

				for _, annotation := range choice.Message.Annotations {
					if annotation.Type == "url_citation" {
						if !yield(fantasy.StreamPart{
							Type:       fantasy.StreamPartTypeSource,
							ID:         acc.ID,
							SourceType: fantasy.SourceTypeURL,
							URL:        annotation.URLCitation.URL,
							Title:      annotation.URLCitation.Title,
						}) {
							return
						}
					}
				}
			}
			mappedFinishReason := o.mapFinishReasonFunc(finishReason)
			if len(acc.Choices) > 0 {
				choice := acc.Choices[0]
				if len(choice.Message.ToolCalls) > 0 {
					mappedFinishReason = fantasy.FinishReasonToolCalls
				}
			}
			yield(fantasy.StreamPart{
				Type:             fantasy.StreamPartTypeFinish,
				Usage:            usage,
				FinishReason:     mappedFinishReason,
				ProviderMetadata: providerMetadata,
			})
			return
		} else { //nolint: revive
			yield(fantasy.StreamPart{
				Type:  fantasy.StreamPartTypeError,
				Error: toProviderErr(err),
			})
			return
		}
	}, nil
}

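// isReasoningModel reports whether the model ID belongs to one of the reasoning
// model families handled here (o1, o3, o4, the *-oss models, or gpt-5), which
// accept a reduced set of sampling parameters.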
func isReasoningModel(modelID string) bool {
	return strings.HasPrefix(modelID, "o1") || strings.Contains(modelID, "-o1") ||
		strings.HasPrefix(modelID, "o3") || strings.Contains(modelID, "-o3") ||
		strings.HasPrefix(modelID, "o4") || strings.Contains(modelID, "-o4") ||
		strings.HasPrefix(modelID, "oss") || strings.Contains(modelID, "-oss") ||
		strings.Contains(modelID, "gpt-5") || strings.Contains(modelID, "gpt-5-chat")
}

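// isSearchPreviewModel reports whether the model ID refers to a search-preview model.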
func isSearchPreviewModel(modelID string) bool {
	return strings.Contains(modelID, "search-preview")
}

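// supportsFlexProcessing reports whether the model is expected to accept the
// flex processing service tier.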
func supportsFlexProcessing(modelID string) bool {
	return strings.HasPrefix(modelID, "o3") || strings.Contains(modelID, "-o3") ||
		strings.Contains(modelID, "o4-mini") || strings.Contains(modelID, "gpt-5")
}

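// supportsPriorityProcessing reports whether the model is expected to accept the
// priority processing service tier.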
func supportsPriorityProcessing(modelID string) bool {
	return strings.Contains(modelID, "gpt-4") || strings.Contains(modelID, "gpt-5") ||
		strings.Contains(modelID, "gpt-5-mini") || strings.HasPrefix(modelID, "o3") ||
		strings.Contains(modelID, "-o3") || strings.Contains(modelID, "o4-mini")
}

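// toOpenAiTools converts fantasy tool definitions and the optional tool choice
// into their OpenAI chat completion equivalents, emitting warnings for
// unsupported tool types.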
func toOpenAiTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []fantasy.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == fantasy.ToolTypeFunction {
			ft, ok := tool.(fantasy.FunctionTool)
			if !ok {
				continue
			}
			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
				OfFunction: &openai.ChatCompletionFunctionToolParam{
					Function: shared.FunctionDefinitionParam{
						Name:        ft.Name,
						Description: param.NewOpt(ft.Description),
						Parameters:  openai.FunctionParameters(ft.InputSchema),
						Strict:      param.NewOpt(false),
					},
					Type: "function",
				},
			})
			continue
		}

		warnings = append(warnings, fantasy.CallWarning{
			Type:    fantasy.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return openAiTools, openAiToolChoice, warnings
	}

	switch *toolChoice {
	case fantasy.ToolChoiceAuto:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("auto"),
		}
	case fantasy.ToolChoiceNone:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("none"),
		}
	default:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
				Type: "function",
				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
					Name: string(*toolChoice),
				},
			},
		}
	}
	return openAiTools, openAiToolChoice, warnings
}

// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
	var annotations []openai.ChatCompletionMessageAnnotation

	// Parse the raw JSON to extract annotations
	var deltaData map[string]any
	if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
		return annotations
	}

	// Check if annotations exist in the delta
	if annotationsData, ok := deltaData["annotations"].([]any); ok {
		for _, annotationData := range annotationsData {
			if annotationMap, ok := annotationData.(map[string]any); ok {
				if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
					if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
						url, urlOk := urlCitationData["url"].(string)
						title, titleOk := urlCitationData["title"].(string)
						if urlOk && titleOk {
							annotation := openai.ChatCompletionMessageAnnotation{
								Type: "url_citation",
								URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
									URL:   url,
									Title: title,
								},
							}
							annotations = append(annotations, annotation)
						}
					}
				}
			}
		}
	}

	return annotations
}

// GenerateObject implements fantasy.LanguageModel.
func (o languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
	switch o.objectMode {
	case fantasy.ObjectModeText:
		return object.GenerateWithText(ctx, o, call)
	case fantasy.ObjectModeTool:
		return object.GenerateWithTool(ctx, o, call)
	default:
		return o.generateObjectWithJSONMode(ctx, call)
	}
}

// StreamObject implements fantasy.LanguageModel.
func (o languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
	switch o.objectMode {
	case fantasy.ObjectModeTool:
		return object.StreamWithTool(ctx, o, call)
	case fantasy.ObjectModeText:
		return object.StreamWithText(ctx, o, call)
	default:
		return o.streamObjectWithJSONMode(ctx, call)
	}
}

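// generateObjectWithJSONMode generates a structured object by requesting
// OpenAI's strict JSON schema response format and validating the result
// against the call schema.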
func (o languageModel) generateObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
	jsonSchemaMap := schema.ToMap(call.Schema)

	addAdditionalPropertiesFalse(jsonSchemaMap)

	schemaName := call.SchemaName
	if schemaName == "" {
		schemaName = "response"
	}

	fantasyCall := fantasy.Call{
		Prompt:           call.Prompt,
		MaxOutputTokens:  call.MaxOutputTokens,
		Temperature:      call.Temperature,
		TopP:             call.TopP,
		PresencePenalty:  call.PresencePenalty,
		FrequencyPenalty: call.FrequencyPenalty,
		ProviderOptions:  call.ProviderOptions,
	}

	params, warnings, err := o.prepareParams(fantasyCall)
	if err != nil {
		return nil, err
	}

	params.ResponseFormat = openai.ChatCompletionNewParamsResponseFormatUnion{
		OfJSONSchema: &shared.ResponseFormatJSONSchemaParam{
			JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
				Name:        schemaName,
				Description: param.NewOpt(call.SchemaDescription),
				Schema:      jsonSchemaMap,
				Strict:      param.NewOpt(true),
			},
		},
	}

	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, toProviderErr(err)
	}

	if len(response.Choices) == 0 {
		usage, _ := o.usageFunc(*response)
		return nil, &fantasy.NoObjectGeneratedError{
			RawText:      "",
			ParseError:   fmt.Errorf("no choices in response"),
			Usage:        usage,
			FinishReason: fantasy.FinishReasonUnknown,
		}
	}

	choice := response.Choices[0]
	jsonText := choice.Message.Content

	var obj any
	if call.RepairText != nil {
		obj, err = schema.ParseAndValidateWithRepair(ctx, jsonText, call.Schema, call.RepairText)
	} else {
		obj, err = schema.ParseAndValidate(jsonText, call.Schema)
	}

	usage, _ := o.usageFunc(*response)
	finishReason := o.mapFinishReasonFunc(choice.FinishReason)

	if err != nil {
		if nogErr, ok := err.(*fantasy.NoObjectGeneratedError); ok {
			nogErr.Usage = usage
			nogErr.FinishReason = finishReason
		}
		return nil, err
	}

	return &fantasy.ObjectResponse{
		Object:       obj,
		RawText:      jsonText,
		Usage:        usage,
		FinishReason: finishReason,
		Warnings:     warnings,
	}, nil
}

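// streamObjectWithJSONMode streams a structured object using OpenAI's strict
// JSON schema response format, emitting a new object part whenever the
// accumulated partial JSON parses and validates against the call schema.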
func (o languageModel) streamObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
	jsonSchemaMap := schema.ToMap(call.Schema)

	addAdditionalPropertiesFalse(jsonSchemaMap)

	schemaName := call.SchemaName
	if schemaName == "" {
		schemaName = "response"
	}

	fantasyCall := fantasy.Call{
		Prompt:           call.Prompt,
		MaxOutputTokens:  call.MaxOutputTokens,
		Temperature:      call.Temperature,
		TopP:             call.TopP,
		PresencePenalty:  call.PresencePenalty,
		FrequencyPenalty: call.FrequencyPenalty,
		ProviderOptions:  call.ProviderOptions,
	}

	params, warnings, err := o.prepareParams(fantasyCall)
	if err != nil {
		return nil, err
	}

	params.ResponseFormat = openai.ChatCompletionNewParamsResponseFormatUnion{
		OfJSONSchema: &shared.ResponseFormatJSONSchemaParam{
			JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
				Name:        schemaName,
				Description: param.NewOpt(call.SchemaDescription),
				Schema:      jsonSchemaMap,
				Strict:      param.NewOpt(true),
			},
		},
	}

	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)

	return func(yield func(fantasy.ObjectStreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(fantasy.ObjectStreamPart{
				Type:     fantasy.ObjectStreamPartTypeObject,
				Warnings: warnings,
			}) {
				return
			}
		}

		var accumulated string
		var lastParsedObject any
		var usage fantasy.Usage
		var finishReason fantasy.FinishReason
		var providerMetadata fantasy.ProviderMetadata
		var streamErr error

		for stream.Next() {
			chunk := stream.Current()

			// Update usage
			usage, providerMetadata = o.streamUsageFunc(chunk, make(map[string]any), providerMetadata)

			if len(chunk.Choices) == 0 {
				continue
			}

			choice := chunk.Choices[0]
			if choice.FinishReason != "" {
				finishReason = o.mapFinishReasonFunc(choice.FinishReason)
			}

			if choice.Delta.Content != "" {
				accumulated += choice.Delta.Content

				obj, state, parseErr := schema.ParsePartialJSON(accumulated)

				if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
					if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
						if !reflect.DeepEqual(obj, lastParsedObject) {
							if !yield(fantasy.ObjectStreamPart{
								Type:   fantasy.ObjectStreamPartTypeObject,
								Object: obj,
							}) {
								return
							}
							lastParsedObject = obj
						}
					}
				}

				if state == schema.ParseStateFailed && call.RepairText != nil {
					repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
					if repairErr == nil {
						obj2, state2, _ := schema.ParsePartialJSON(repairedText)
						if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
							schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
							if !reflect.DeepEqual(obj2, lastParsedObject) {
								if !yield(fantasy.ObjectStreamPart{
									Type:   fantasy.ObjectStreamPartTypeObject,
									Object: obj2,
								}) {
									return
								}
								lastParsedObject = obj2
							}
						}
					}
				}
			}
		}

		err := stream.Err()
		if err != nil && !errors.Is(err, io.EOF) {
			streamErr = toProviderErr(err)
			yield(fantasy.ObjectStreamPart{
				Type:  fantasy.ObjectStreamPartTypeError,
				Error: streamErr,
			})
			return
		}

		if lastParsedObject != nil {
			yield(fantasy.ObjectStreamPart{
				Type:             fantasy.ObjectStreamPartTypeFinish,
				Usage:            usage,
				FinishReason:     finishReason,
				ProviderMetadata: providerMetadata,
			})
		} else {
			yield(fantasy.ObjectStreamPart{
				Type: fantasy.ObjectStreamPartTypeError,
				Error: &fantasy.NoObjectGeneratedError{
					RawText:      accumulated,
					ParseError:   fmt.Errorf("no valid object generated in stream"),
					Usage:        usage,
					FinishReason: finishReason,
				},
			})
		}
	}, nil
}

// addAdditionalPropertiesFalse recursively adds "additionalProperties": false to all object schemas.
// This is required by OpenAI's strict mode for structured outputs.
func addAdditionalPropertiesFalse(schemaMap map[string]any) {
	if schemaMap["type"] == "object" {
		if _, hasAdditional := schemaMap["additionalProperties"]; !hasAdditional {
			schemaMap["additionalProperties"] = false
		}

		// Recursively process nested properties
		if properties, ok := schemaMap["properties"].(map[string]any); ok {
			for _, propValue := range properties {
				if propSchema, ok := propValue.(map[string]any); ok {
					addAdditionalPropertiesFalse(propSchema)
				}
			}
		}
	}

	// Handle array items
	if items, ok := schemaMap["items"].(map[string]any); ok {
		addAdditionalPropertiesFalse(items)
	}
}