language_model.go

package openai

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"reflect"
	"strings"

	"charm.land/fantasy"
	"charm.land/fantasy/object"
	"charm.land/fantasy/schema"
	xjson "github.com/charmbracelet/x/json"
	"github.com/google/uuid"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/openai/openai-go/v2/shared"
)

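// languageModel is the Chat Completions-backed implementation of
// fantasy.LanguageModel. Its behavior can be customized per provider through
// the hook functions configured via LanguageModelOption values.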
type languageModel struct {
	provider                   string
	modelID                    string
	client                     openai.Client
	objectMode                 fantasy.ObjectMode
	prepareCallFunc            LanguageModelPrepareCallFunc
	mapFinishReasonFunc        LanguageModelMapFinishReasonFunc
	extraContentFunc           LanguageModelExtraContentFunc
	usageFunc                  LanguageModelUsageFunc
	streamUsageFunc            LanguageModelStreamUsageFunc
	streamExtraFunc            LanguageModelStreamExtraFunc
	streamProviderMetadataFunc LanguageModelStreamProviderMetadataFunc
	toPromptFunc               LanguageModelToPromptFunc
}

// LanguageModelOption is a function that configures a languageModel.
type LanguageModelOption = func(*languageModel)

// WithLanguageModelPrepareCallFunc sets the prepare call function for the language model.
func WithLanguageModelPrepareCallFunc(fn LanguageModelPrepareCallFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.prepareCallFunc = fn
	}
}

// WithLanguageModelMapFinishReasonFunc sets the map finish reason function for the language model.
func WithLanguageModelMapFinishReasonFunc(fn LanguageModelMapFinishReasonFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.mapFinishReasonFunc = fn
	}
}

// WithLanguageModelExtraContentFunc sets the extra content function for the language model.
func WithLanguageModelExtraContentFunc(fn LanguageModelExtraContentFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.extraContentFunc = fn
	}
}

// WithLanguageModelStreamExtraFunc sets the stream extra function for the language model.
func WithLanguageModelStreamExtraFunc(fn LanguageModelStreamExtraFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.streamExtraFunc = fn
	}
}

// WithLanguageModelUsageFunc sets the usage function for the language model.
func WithLanguageModelUsageFunc(fn LanguageModelUsageFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.usageFunc = fn
	}
}

// WithLanguageModelStreamUsageFunc sets the stream usage function for the language model.
func WithLanguageModelStreamUsageFunc(fn LanguageModelStreamUsageFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.streamUsageFunc = fn
	}
}

// WithLanguageModelToPromptFunc sets the function used to convert the prompt into provider messages.
func WithLanguageModelToPromptFunc(fn LanguageModelToPromptFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.toPromptFunc = fn
	}
}

// WithLanguageModelObjectMode sets the object generation mode.
func WithLanguageModelObjectMode(om fantasy.ObjectMode) LanguageModelOption {
	return func(l *languageModel) {
		// ObjectModeJSON is not supported; fall back to auto.
		if om == fantasy.ObjectModeJSON {
			om = fantasy.ObjectModeAuto
		}
		l.objectMode = om
	}
}

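// newLanguageModel constructs a languageModel for the given model ID, provider
// name, and client, applying any LanguageModelOption overrides on top of the
// default prepare-call, finish-reason, usage, and prompt conversion functions.
//
// For example (with an illustrative model ID):
//
//	model := newLanguageModel("gpt-4o", "openai", client,
//		WithLanguageModelObjectMode(fantasy.ObjectModeTool),
//	)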
func newLanguageModel(modelID string, provider string, client openai.Client, opts ...LanguageModelOption) languageModel {
	model := languageModel{
		modelID:                    modelID,
		provider:                   provider,
		client:                     client,
		objectMode:                 fantasy.ObjectModeAuto,
		prepareCallFunc:            DefaultPrepareCallFunc,
		mapFinishReasonFunc:        DefaultMapFinishReasonFunc,
		usageFunc:                  DefaultUsageFunc,
		streamUsageFunc:            DefaultStreamUsageFunc,
		streamProviderMetadataFunc: DefaultStreamProviderMetadataFunc,
		toPromptFunc:               DefaultToPrompt,
	}

	for _, o := range opts {
		o(&model)
	}
	return model
}

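// streamToolCall accumulates a tool call across streaming deltas so that a
// complete tool-call part can be emitted once its arguments form valid JSON.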
type streamToolCall struct {
	id          string
	name        string
	arguments   string
	hasFinished bool
}

// Model implements fantasy.LanguageModel.
func (o languageModel) Model() string {
	return o.modelID
}

// Provider implements fantasy.LanguageModel.
func (o languageModel) Provider() string {
	return o.provider
}

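// prepareParams converts a fantasy.Call into Chat Completions request
// parameters, collecting warnings for settings that the target model does not
// support.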
func (o languageModel) prepareParams(call fantasy.Call) (*openai.ChatCompletionNewParams, []fantasy.CallWarning, error) {
	params := &openai.ChatCompletionNewParams{}
	messages, warnings := o.toPromptFunc(call.Prompt, o.provider, o.modelID)
	if call.TopK != nil {
		warnings = append(warnings, fantasy.CallWarning{
			Type:    fantasy.CallWarningTypeUnsupportedSetting,
			Setting: "top_k",
		})
	}

	if call.MaxOutputTokens != nil {
		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
	}
	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}
	if call.FrequencyPenalty != nil {
		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
	}
	if call.PresencePenalty != nil {
		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
	}

	if isReasoningModel(o.modelID) {
		// remove unsupported settings for reasoning models
		// see https://platform.openai.com/docs/guides/reasoning#limitations
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for reasoning models",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "top_p",
				Details: "top_p is not supported for reasoning models",
			})
		}
		if call.FrequencyPenalty != nil {
			params.FrequencyPenalty = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "frequency_penalty",
				Details: "frequency_penalty is not supported for reasoning models",
			})
		}
		if call.PresencePenalty != nil {
			params.PresencePenalty = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "presence_penalty",
				Details: "presence_penalty is not supported for reasoning models",
			})
		}

		// reasoning models use max_completion_tokens instead of max_tokens
		if call.MaxOutputTokens != nil {
			if !params.MaxCompletionTokens.Valid() {
				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
			}
			params.MaxTokens = param.Opt[int64]{}
		}
	}

	// Handle search preview models
	if isSearchPreviewModel(o.modelID) {
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for the search preview models and has been removed.",
			})
		}
	}

	optionsWarnings, err := o.prepareCallFunc(o, params, call)
	if err != nil {
		return nil, nil, err
	}

	if len(optionsWarnings) > 0 {
		warnings = append(warnings, optionsWarnings...)
	}

	params.Messages = messages
	params.Model = o.modelID

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}
	return params, warnings, nil
}

// Generate implements fantasy.LanguageModel.
func (o languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, toProviderErr(err)
	}

	if len(response.Choices) == 0 {
		return nil, &fantasy.Error{Title: "no response", Message: "no response generated"}
	}
	choice := response.Choices[0]
	content := make([]fantasy.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
	text := choice.Message.Content
	if text != "" {
		content = append(content, fantasy.TextContent{
			Text: text,
		})
	}
	if o.extraContentFunc != nil {
		extraContent := o.extraContentFunc(choice)
		content = append(content, extraContent...)
	}
	for _, tc := range choice.Message.ToolCalls {
		toolCallID := tc.ID
		content = append(content, fantasy.ToolCallContent{
			ProviderExecuted: false,
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, fantasy.SourceContent{
				SourceType: fantasy.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	usage, providerMetadata := o.usageFunc(*response)

	mappedFinishReason := o.mapFinishReasonFunc(choice.FinishReason)
	if len(choice.Message.ToolCalls) > 0 {
		mappedFinishReason = fantasy.FinishReasonToolCalls
	}
	return &fantasy.Response{
		Content:      content,
		Usage:        usage,
		FinishReason: mappedFinishReason,
		ProviderMetadata: fantasy.ProviderMetadata{
			Name: providerMetadata,
		},
		Warnings: warnings,
	}, nil
}

// Stream implements fantasy.LanguageModel.
func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}

	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
	isActiveText := false
	toolCalls := make(map[int64]streamToolCall)

	providerMetadata := fantasy.ProviderMetadata{
		Name: &ProviderMetadata{},
	}
	acc := openai.ChatCompletionAccumulator{}
	extraContext := make(map[string]any)
	var usage fantasy.Usage
	var finishReason string
	return func(yield func(fantasy.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(fantasy.StreamPart{
				Type:     fantasy.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}
		for stream.Next() {
			chunk := stream.Current()
			acc.AddChunk(chunk)
			usage, providerMetadata = o.streamUsageFunc(chunk, extraContext, providerMetadata)
			if len(chunk.Choices) == 0 {
				continue
			}
			for _, choice := range chunk.Choices {
				if choice.FinishReason != "" {
					finishReason = choice.FinishReason
				}
				switch {
				case choice.Delta.Content != "":
					if !isActiveText {
						isActiveText = true
						if !yield(fantasy.StreamPart{
							Type: fantasy.StreamPartTypeTextStart,
							ID:   "0",
						}) {
							return
						}
					}
					if !yield(fantasy.StreamPart{
						Type:  fantasy.StreamPartTypeTextDelta,
						ID:    "0",
						Delta: choice.Delta.Content,
					}) {
						return
					}
				case len(choice.Delta.ToolCalls) > 0:
					if isActiveText {
						isActiveText = false
						if !yield(fantasy.StreamPart{
							Type: fantasy.StreamPartTypeTextEnd,
							ID:   "0",
						}) {
							return
						}
					}

					for _, toolCallDelta := range choice.Delta.ToolCalls {
						if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
							if existingToolCall.hasFinished {
								continue
							}
							if toolCallDelta.Function.Arguments != "" {
								existingToolCall.arguments += toolCallDelta.Function.Arguments
							}
							if !yield(fantasy.StreamPart{
								Type:  fantasy.StreamPartTypeToolInputDelta,
								ID:    existingToolCall.id,
								Delta: toolCallDelta.Function.Arguments,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = existingToolCall
							if xjson.IsValid(existingToolCall.arguments) {
								if !yield(fantasy.StreamPart{
									Type: fantasy.StreamPartTypeToolInputEnd,
									ID:   existingToolCall.id,
								}) {
									return
								}

								if !yield(fantasy.StreamPart{
									Type:          fantasy.StreamPartTypeToolCall,
									ID:            existingToolCall.id,
									ToolCallName:  existingToolCall.name,
									ToolCallInput: existingToolCall.arguments,
								}) {
									return
								}
								existingToolCall.hasFinished = true
								toolCalls[toolCallDelta.Index] = existingToolCall
							}
						} else {
							var err error
							if toolCallDelta.Type != "function" {
								err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function' type."}
							}
							if toolCallDelta.ID == "" {
								err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'id' to be a string."}
							}
							if toolCallDelta.Function.Name == "" {
								err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function.name' to be a string."}
							}
							if err != nil {
								// Yield the validation error itself; the stream
								// has not failed at this point.
								yield(fantasy.StreamPart{
									Type:  fantasy.StreamPartTypeError,
									Error: err,
								})
								return
							}

							if !yield(fantasy.StreamPart{
								Type:         fantasy.StreamPartTypeToolInputStart,
								ID:           toolCallDelta.ID,
								ToolCallName: toolCallDelta.Function.Name,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = streamToolCall{
								id:        toolCallDelta.ID,
								name:      toolCallDelta.Function.Name,
								arguments: toolCallDelta.Function.Arguments,
							}

							exTc := toolCalls[toolCallDelta.Index]
							if exTc.arguments != "" {
								if !yield(fantasy.StreamPart{
									Type:  fantasy.StreamPartTypeToolInputDelta,
									ID:    exTc.id,
									Delta: exTc.arguments,
								}) {
									return
								}
								if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
									if !yield(fantasy.StreamPart{
										Type: fantasy.StreamPartTypeToolInputEnd,
										ID:   toolCallDelta.ID,
									}) {
										return
									}

									if !yield(fantasy.StreamPart{
										Type:          fantasy.StreamPartTypeToolCall,
										ID:            exTc.id,
										ToolCallName:  exTc.name,
										ToolCallInput: exTc.arguments,
									}) {
										return
									}
									exTc.hasFinished = true
									toolCalls[toolCallDelta.Index] = exTc
								}
							}
							continue
						}
					}
				}

				if o.streamExtraFunc != nil {
					updatedContext, shouldContinue := o.streamExtraFunc(chunk, yield, extraContext)
					if !shouldContinue {
						return
					}
					extraContext = updatedContext
				}
			}

			for _, choice := range chunk.Choices {
				if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
					for _, annotation := range annotations {
						if annotation.Type == "url_citation" {
							if !yield(fantasy.StreamPart{
								Type:       fantasy.StreamPartTypeSource,
								ID:         uuid.NewString(),
								SourceType: fantasy.SourceTypeURL,
								URL:        annotation.URLCitation.URL,
								Title:      annotation.URLCitation.Title,
							}) {
								return
							}
						}
					}
				}
			}
		}
		err := stream.Err()
		if err == nil || errors.Is(err, io.EOF) {
			if isActiveText {
				isActiveText = false
				if !yield(fantasy.StreamPart{
					Type: fantasy.StreamPartTypeTextEnd,
					ID:   "0",
				}) {
					return
				}
			}

			if len(acc.Choices) > 0 {
				choice := acc.Choices[0]
				providerMetadata = o.streamProviderMetadataFunc(choice, providerMetadata)

				for _, annotation := range choice.Message.Annotations {
					if annotation.Type == "url_citation" {
						if !yield(fantasy.StreamPart{
							Type:       fantasy.StreamPartTypeSource,
							ID:         acc.ID,
							SourceType: fantasy.SourceTypeURL,
							URL:        annotation.URLCitation.URL,
							Title:      annotation.URLCitation.Title,
						}) {
							return
						}
					}
				}
			}
			mappedFinishReason := o.mapFinishReasonFunc(finishReason)
			if len(acc.Choices) > 0 {
				choice := acc.Choices[0]
				if len(choice.Message.ToolCalls) > 0 {
					mappedFinishReason = fantasy.FinishReasonToolCalls
				}
			}
			yield(fantasy.StreamPart{
				Type:             fantasy.StreamPartTypeFinish,
				Usage:            usage,
				FinishReason:     mappedFinishReason,
				ProviderMetadata: providerMetadata,
			})
			return
		} else { //nolint: revive
			yield(fantasy.StreamPart{
				Type:  fantasy.StreamPartTypeError,
				Error: toProviderErr(err),
			})
			return
		}
	}, nil
}

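// isReasoningModel reports whether the model ID belongs to a reasoning model
// family (o1, o3, o4, *-oss, gpt-5), which rejects sampling settings such as
// temperature and expects max_completion_tokens instead of max_tokens.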
func isReasoningModel(modelID string) bool {
	return strings.HasPrefix(modelID, "o1") || strings.Contains(modelID, "-o1") ||
		strings.HasPrefix(modelID, "o3") || strings.Contains(modelID, "-o3") ||
		strings.HasPrefix(modelID, "o4") || strings.Contains(modelID, "-o4") ||
		strings.HasPrefix(modelID, "oss") || strings.Contains(modelID, "-oss") ||
		strings.Contains(modelID, "gpt-5") || strings.Contains(modelID, "gpt-5-chat")
}

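// isSearchPreviewModel reports whether the model ID refers to a search-preview
// model, which does not support the temperature setting.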
func isSearchPreviewModel(modelID string) bool {
	return strings.Contains(modelID, "search-preview")
}

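// supportsFlexProcessing reports whether the model ID belongs to a family that
// supports flex processing (o3, o4-mini, gpt-5).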
func supportsFlexProcessing(modelID string) bool {
	return strings.HasPrefix(modelID, "o3") || strings.Contains(modelID, "-o3") ||
		strings.Contains(modelID, "o4-mini") || strings.Contains(modelID, "gpt-5")
}

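// supportsPriorityProcessing reports whether the model ID belongs to a family
// that supports priority processing (gpt-4, gpt-5, o3, o4-mini).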
func supportsPriorityProcessing(modelID string) bool {
	return strings.Contains(modelID, "gpt-4") || strings.Contains(modelID, "gpt-5") ||
		strings.Contains(modelID, "gpt-5-mini") || strings.HasPrefix(modelID, "o3") ||
		strings.Contains(modelID, "-o3") || strings.Contains(modelID, "o4-mini")
}

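// toOpenAiTools converts fantasy function tools and the optional tool choice
// into their Chat Completions equivalents, emitting warnings for tool types
// that are not supported.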
func toOpenAiTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []fantasy.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == fantasy.ToolTypeFunction {
			ft, ok := tool.(fantasy.FunctionTool)
			if !ok {
				continue
			}
			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
				OfFunction: &openai.ChatCompletionFunctionToolParam{
					Function: shared.FunctionDefinitionParam{
						Name:        ft.Name,
						Description: param.NewOpt(ft.Description),
						Parameters:  openai.FunctionParameters(ft.InputSchema),
						Strict:      param.NewOpt(false),
					},
					Type: "function",
				},
			})
			continue
		}

		warnings = append(warnings, fantasy.CallWarning{
			Type:    fantasy.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return openAiTools, openAiToolChoice, warnings
	}

	switch *toolChoice {
	case fantasy.ToolChoiceAuto:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("auto"),
		}
	case fantasy.ToolChoiceNone:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("none"),
		}
	default:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
				Type: "function",
				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
					Name: string(*toolChoice),
				},
			},
		}
	}
	return openAiTools, openAiToolChoice, warnings
}

// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
	var annotations []openai.ChatCompletionMessageAnnotation

	// Parse the raw JSON to extract annotations
	var deltaData map[string]any
	if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
		return annotations
	}

	// Check if annotations exist in the delta
	if annotationsData, ok := deltaData["annotations"].([]any); ok {
		for _, annotationData := range annotationsData {
			if annotationMap, ok := annotationData.(map[string]any); ok {
				if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
					if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
						// Use checked assertions so a missing or non-string
						// field cannot panic.
						url, _ := urlCitationData["url"].(string)
						title, _ := urlCitationData["title"].(string)
						annotation := openai.ChatCompletionMessageAnnotation{
							Type: "url_citation",
							URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
								URL:   url,
								Title: title,
							},
						}
						annotations = append(annotations, annotation)
					}
				}
			}
		}
	}

	return annotations
}

// GenerateObject implements fantasy.LanguageModel.
func (o languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
	switch o.objectMode {
	case fantasy.ObjectModeText:
		return object.GenerateWithText(ctx, o, call)
	case fantasy.ObjectModeTool:
		return object.GenerateWithTool(ctx, o, call)
	default:
		return o.generateObjectWithJSONMode(ctx, call)
	}
}

// StreamObject implements fantasy.LanguageModel.
func (o languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
	switch o.objectMode {
	case fantasy.ObjectModeTool:
		return object.StreamWithTool(ctx, o, call)
	case fantasy.ObjectModeText:
		return object.StreamWithText(ctx, o, call)
	default:
		return o.streamObjectWithJSONMode(ctx, call)
	}
}

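// generateObjectWithJSONMode generates an object by requesting a strict JSON
// schema response format and then parsing and validating the returned JSON
// against the call's schema.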
func (o languageModel) generateObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
	jsonSchemaMap := schema.ToMap(call.Schema)

	addAdditionalPropertiesFalse(jsonSchemaMap)

	schemaName := call.SchemaName
	if schemaName == "" {
		schemaName = "response"
	}

	fantasyCall := fantasy.Call{
		Prompt:           call.Prompt,
		MaxOutputTokens:  call.MaxOutputTokens,
		Temperature:      call.Temperature,
		TopP:             call.TopP,
		PresencePenalty:  call.PresencePenalty,
		FrequencyPenalty: call.FrequencyPenalty,
		ProviderOptions:  call.ProviderOptions,
	}

	params, warnings, err := o.prepareParams(fantasyCall)
	if err != nil {
		return nil, err
	}

	params.ResponseFormat = openai.ChatCompletionNewParamsResponseFormatUnion{
		OfJSONSchema: &shared.ResponseFormatJSONSchemaParam{
			JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
				Name:        schemaName,
				Description: param.NewOpt(call.SchemaDescription),
				Schema:      jsonSchemaMap,
				Strict:      param.NewOpt(true),
			},
		},
	}

	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, toProviderErr(err)
	}

	if len(response.Choices) == 0 {
		usage, _ := o.usageFunc(*response)
		return nil, &fantasy.NoObjectGeneratedError{
			RawText:      "",
			ParseError:   fmt.Errorf("no choices in response"),
			Usage:        usage,
			FinishReason: fantasy.FinishReasonUnknown,
		}
	}

	choice := response.Choices[0]
	jsonText := choice.Message.Content

	var obj any
	if call.RepairText != nil {
		obj, err = schema.ParseAndValidateWithRepair(ctx, jsonText, call.Schema, call.RepairText)
	} else {
		obj, err = schema.ParseAndValidate(jsonText, call.Schema)
	}

	usage, _ := o.usageFunc(*response)
	finishReason := o.mapFinishReasonFunc(choice.FinishReason)

	if err != nil {
		if nogErr, ok := err.(*fantasy.NoObjectGeneratedError); ok {
			nogErr.Usage = usage
			nogErr.FinishReason = finishReason
		}
		return nil, err
	}

	return &fantasy.ObjectResponse{
		Object:       obj,
		RawText:      jsonText,
		Usage:        usage,
		FinishReason: finishReason,
		Warnings:     warnings,
	}, nil
}

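// streamObjectWithJSONMode streams an object using the strict JSON schema
// response format, emitting an object part whenever the accumulated partial
// JSON parses and validates against the call's schema.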
func (o languageModel) streamObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
	jsonSchemaMap := schema.ToMap(call.Schema)

	addAdditionalPropertiesFalse(jsonSchemaMap)

	schemaName := call.SchemaName
	if schemaName == "" {
		schemaName = "response"
	}

	fantasyCall := fantasy.Call{
		Prompt:           call.Prompt,
		MaxOutputTokens:  call.MaxOutputTokens,
		Temperature:      call.Temperature,
		TopP:             call.TopP,
		PresencePenalty:  call.PresencePenalty,
		FrequencyPenalty: call.FrequencyPenalty,
		ProviderOptions:  call.ProviderOptions,
	}

	params, warnings, err := o.prepareParams(fantasyCall)
	if err != nil {
		return nil, err
	}

	params.ResponseFormat = openai.ChatCompletionNewParamsResponseFormatUnion{
		OfJSONSchema: &shared.ResponseFormatJSONSchemaParam{
			JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
				Name:        schemaName,
				Description: param.NewOpt(call.SchemaDescription),
				Schema:      jsonSchemaMap,
				Strict:      param.NewOpt(true),
			},
		},
	}

	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)

	return func(yield func(fantasy.ObjectStreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(fantasy.ObjectStreamPart{
				Type:     fantasy.ObjectStreamPartTypeObject,
				Warnings: warnings,
			}) {
				return
			}
		}

		var accumulated string
		var lastParsedObject any
		var usage fantasy.Usage
		var finishReason fantasy.FinishReason
		var providerMetadata fantasy.ProviderMetadata
		var streamErr error

		for stream.Next() {
			chunk := stream.Current()

			// Update usage
			usage, providerMetadata = o.streamUsageFunc(chunk, make(map[string]any), providerMetadata)

			if len(chunk.Choices) == 0 {
				continue
			}

			choice := chunk.Choices[0]
			if choice.FinishReason != "" {
				finishReason = o.mapFinishReasonFunc(choice.FinishReason)
			}

			if choice.Delta.Content != "" {
				accumulated += choice.Delta.Content

				obj, state, parseErr := schema.ParsePartialJSON(accumulated)

				if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
					if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
						if !reflect.DeepEqual(obj, lastParsedObject) {
							if !yield(fantasy.ObjectStreamPart{
								Type:   fantasy.ObjectStreamPartTypeObject,
								Object: obj,
							}) {
								return
							}
							lastParsedObject = obj
						}
					}
				}

				if state == schema.ParseStateFailed && call.RepairText != nil {
					repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
					if repairErr == nil {
						obj2, state2, _ := schema.ParsePartialJSON(repairedText)
						if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
							schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
							if !reflect.DeepEqual(obj2, lastParsedObject) {
								if !yield(fantasy.ObjectStreamPart{
									Type:   fantasy.ObjectStreamPartTypeObject,
									Object: obj2,
								}) {
									return
								}
								lastParsedObject = obj2
							}
						}
					}
				}
			}
		}

		err := stream.Err()
		if err != nil && !errors.Is(err, io.EOF) {
			streamErr = toProviderErr(err)
			yield(fantasy.ObjectStreamPart{
				Type:  fantasy.ObjectStreamPartTypeError,
				Error: streamErr,
			})
			return
		}

		if lastParsedObject != nil {
			yield(fantasy.ObjectStreamPart{
				Type:             fantasy.ObjectStreamPartTypeFinish,
				Usage:            usage,
				FinishReason:     finishReason,
				ProviderMetadata: providerMetadata,
			})
		} else {
			yield(fantasy.ObjectStreamPart{
				Type: fantasy.ObjectStreamPartTypeError,
				Error: &fantasy.NoObjectGeneratedError{
					RawText:      accumulated,
					ParseError:   fmt.Errorf("no valid object generated in stream"),
					Usage:        usage,
					FinishReason: finishReason,
				},
			})
		}
	}, nil
}

// addAdditionalPropertiesFalse recursively adds "additionalProperties": false to all object schemas.
// This is required by OpenAI's strict mode for structured outputs.
func addAdditionalPropertiesFalse(schema map[string]any) {
	if schema["type"] == "object" {
		if _, hasAdditional := schema["additionalProperties"]; !hasAdditional {
			schema["additionalProperties"] = false
		}

		// Recursively process nested properties
		if properties, ok := schema["properties"].(map[string]any); ok {
			for _, propValue := range properties {
				if propSchema, ok := propValue.(map[string]any); ok {
					addAdditionalPropertiesFalse(propSchema)
				}
			}
		}
	}

	// Handle array items
	if items, ok := schema["items"].(map[string]any); ok {
		addAdditionalPropertiesFalse(items)
	}
}