1package openai
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "reflect"
10 "strings"
11
12 "charm.land/fantasy"
13 "charm.land/fantasy/object"
14 "charm.land/fantasy/schema"
15 xjson "github.com/charmbracelet/x/json"
16 "github.com/google/uuid"
17 "github.com/openai/openai-go/v2"
18 "github.com/openai/openai-go/v2/packages/param"
19 "github.com/openai/openai-go/v2/shared"
20)
21
// languageModel is the chat-completions-based implementation of
// fantasy.LanguageModel for OpenAI and OpenAI-compatible providers.
// Provider-specific behavior is injected through the hook funcs below,
// which newLanguageModel fills with Default* implementations.
type languageModel struct {
	provider                   string               // provider name returned by Provider()
	modelID                    string               // model identifier returned by Model()
	client                     openai.Client        // underlying OpenAI SDK client
	objectMode                 fantasy.ObjectMode   // strategy used by GenerateObject/StreamObject
	prepareCallFunc            LanguageModelPrepareCallFunc
	mapFinishReasonFunc        LanguageModelMapFinishReasonFunc
	extraContentFunc           LanguageModelExtraContentFunc
	usageFunc                  LanguageModelUsageFunc
	streamUsageFunc            LanguageModelStreamUsageFunc
	streamExtraFunc            LanguageModelStreamExtraFunc
	streamProviderMetadataFunc LanguageModelStreamProviderMetadataFunc
	toPromptFunc               LanguageModelToPromptFunc
	noDefaultUserAgent         bool // when true, per-call User-Agent overrides are skipped (OpenRouter quirk)
}
37
// LanguageModelOption is a function that configures a languageModel.
// Options are applied in order by newLanguageModel.
type LanguageModelOption = func(*languageModel)
40
41// WithLanguageModelPrepareCallFunc sets the prepare call function for the language model.
42func WithLanguageModelPrepareCallFunc(fn LanguageModelPrepareCallFunc) LanguageModelOption {
43 return func(l *languageModel) {
44 l.prepareCallFunc = fn
45 }
46}
47
48// WithLanguageModelMapFinishReasonFunc sets the map finish reason function for the language model.
49func WithLanguageModelMapFinishReasonFunc(fn LanguageModelMapFinishReasonFunc) LanguageModelOption {
50 return func(l *languageModel) {
51 l.mapFinishReasonFunc = fn
52 }
53}
54
55// WithLanguageModelExtraContentFunc sets the extra content function for the language model.
56func WithLanguageModelExtraContentFunc(fn LanguageModelExtraContentFunc) LanguageModelOption {
57 return func(l *languageModel) {
58 l.extraContentFunc = fn
59 }
60}
61
62// WithLanguageModelStreamExtraFunc sets the stream extra function for the language model.
63func WithLanguageModelStreamExtraFunc(fn LanguageModelStreamExtraFunc) LanguageModelOption {
64 return func(l *languageModel) {
65 l.streamExtraFunc = fn
66 }
67}
68
69// WithLanguageModelUsageFunc sets the usage function for the language model.
70func WithLanguageModelUsageFunc(fn LanguageModelUsageFunc) LanguageModelOption {
71 return func(l *languageModel) {
72 l.usageFunc = fn
73 }
74}
75
76// WithLanguageModelStreamUsageFunc sets the stream usage function for the language model.
77func WithLanguageModelStreamUsageFunc(fn LanguageModelStreamUsageFunc) LanguageModelOption {
78 return func(l *languageModel) {
79 l.streamUsageFunc = fn
80 }
81}
82
83// WithLanguageModelToPromptFunc sets the to prompt function for the language model.
84func WithLanguageModelToPromptFunc(fn LanguageModelToPromptFunc) LanguageModelOption {
85 return func(l *languageModel) {
86 l.toPromptFunc = fn
87 }
88}
89
90// WithLanguageModelSkipUserAgent prevents per-call User-Agent overrides. This
91// exists solely for OpenRouter, which rejects User-Agent overrides.
92//
93// This function is provisional and may be removed in a future release.
94func WithLanguageModelSkipUserAgent() LanguageModelOption {
95 return func(l *languageModel) {
96 l.noDefaultUserAgent = true
97 }
98}
99
100// WithLanguageModelObjectMode sets the object generation mode.
101func WithLanguageModelObjectMode(om fantasy.ObjectMode) LanguageModelOption {
102 return func(l *languageModel) {
103 // not supported
104 if om == fantasy.ObjectModeJSON {
105 om = fantasy.ObjectModeAuto
106 }
107 l.objectMode = om
108 }
109}
110
111func newLanguageModel(modelID string, provider string, client openai.Client, opts ...LanguageModelOption) languageModel {
112 model := languageModel{
113 modelID: modelID,
114 provider: provider,
115 client: client,
116 objectMode: fantasy.ObjectModeAuto,
117 prepareCallFunc: DefaultPrepareCallFunc,
118 mapFinishReasonFunc: DefaultMapFinishReasonFunc,
119 usageFunc: DefaultUsageFunc,
120 streamUsageFunc: DefaultStreamUsageFunc,
121 streamProviderMetadataFunc: DefaultStreamProviderMetadataFunc,
122 toPromptFunc: DefaultToPrompt,
123 }
124
125 for _, o := range opts {
126 o(&model)
127 }
128 return model
129}
130
// streamToolCall accumulates one tool call while its pieces arrive across
// streaming chunks, keyed by the delta's choice index.
type streamToolCall struct {
	id          string
	name        string
	arguments   string // concatenation of streamed JSON argument fragments
	hasFinished bool   // set once a complete tool call part has been emitted
}
137
138// Model implements fantasy.LanguageModel.
139func (o languageModel) Model() string {
140 return o.modelID
141}
142
143// Provider implements fantasy.LanguageModel.
144func (o languageModel) Provider() string {
145 return o.provider
146}
147
148func (o languageModel) prepareParams(call fantasy.Call) (*openai.ChatCompletionNewParams, []fantasy.CallWarning, error) {
149 params := &openai.ChatCompletionNewParams{}
150 messages, warnings := o.toPromptFunc(call.Prompt, o.provider, o.modelID)
151 if call.TopK != nil {
152 warnings = append(warnings, fantasy.CallWarning{
153 Type: fantasy.CallWarningTypeUnsupportedSetting,
154 Setting: "top_k",
155 })
156 }
157
158 if call.MaxOutputTokens != nil {
159 params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
160 }
161 if call.Temperature != nil {
162 params.Temperature = param.NewOpt(*call.Temperature)
163 }
164 if call.TopP != nil {
165 params.TopP = param.NewOpt(*call.TopP)
166 }
167 if call.FrequencyPenalty != nil {
168 params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
169 }
170 if call.PresencePenalty != nil {
171 params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
172 }
173
174 if isReasoningModel(o.modelID) {
175 // remove unsupported settings for reasoning models
176 // see https://platform.openai.com/docs/guides/reasoning#limitations
177 if call.Temperature != nil {
178 params.Temperature = param.Opt[float64]{}
179 warnings = append(warnings, fantasy.CallWarning{
180 Type: fantasy.CallWarningTypeUnsupportedSetting,
181 Setting: "temperature",
182 Details: "temperature is not supported for reasoning models",
183 })
184 }
185 if call.TopP != nil {
186 params.TopP = param.Opt[float64]{}
187 warnings = append(warnings, fantasy.CallWarning{
188 Type: fantasy.CallWarningTypeUnsupportedSetting,
189 Setting: "TopP",
190 Details: "TopP is not supported for reasoning models",
191 })
192 }
193 if call.FrequencyPenalty != nil {
194 params.FrequencyPenalty = param.Opt[float64]{}
195 warnings = append(warnings, fantasy.CallWarning{
196 Type: fantasy.CallWarningTypeUnsupportedSetting,
197 Setting: "FrequencyPenalty",
198 Details: "FrequencyPenalty is not supported for reasoning models",
199 })
200 }
201 if call.PresencePenalty != nil {
202 params.PresencePenalty = param.Opt[float64]{}
203 warnings = append(warnings, fantasy.CallWarning{
204 Type: fantasy.CallWarningTypeUnsupportedSetting,
205 Setting: "PresencePenalty",
206 Details: "PresencePenalty is not supported for reasoning models",
207 })
208 }
209
210 // reasoning models use max_completion_tokens instead of max_tokens
211 if call.MaxOutputTokens != nil {
212 if !params.MaxCompletionTokens.Valid() {
213 params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
214 }
215 params.MaxTokens = param.Opt[int64]{}
216 }
217 }
218
219 // Handle search preview models
220 if isSearchPreviewModel(o.modelID) {
221 if call.Temperature != nil {
222 params.Temperature = param.Opt[float64]{}
223 warnings = append(warnings, fantasy.CallWarning{
224 Type: fantasy.CallWarningTypeUnsupportedSetting,
225 Setting: "temperature",
226 Details: "temperature is not supported for the search preview models and has been removed.",
227 })
228 }
229 }
230
231 optionsWarnings, err := o.prepareCallFunc(o, params, call)
232 if err != nil {
233 return nil, nil, err
234 }
235
236 if len(optionsWarnings) > 0 {
237 warnings = append(warnings, optionsWarnings...)
238 }
239
240 params.Messages = messages
241 params.Model = o.modelID
242
243 if len(call.Tools) > 0 {
244 tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
245 params.Tools = tools
246 if toolChoice != nil {
247 params.ToolChoice = *toolChoice
248 }
249 warnings = append(warnings, toolWarnings...)
250 }
251 return params, warnings, nil
252}
253
254// Generate implements fantasy.LanguageModel.
255func (o languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
256 params, warnings, err := o.prepareParams(call)
257 if err != nil {
258 return nil, err
259 }
260 response, err := o.client.Chat.Completions.New(ctx, *params, callUARequestOptions(call, o.noDefaultUserAgent)...)
261 if err != nil {
262 return nil, toProviderErr(err)
263 }
264
265 if len(response.Choices) == 0 {
266 return nil, &fantasy.Error{Title: "no response", Message: "no response generated"}
267 }
268 choice := response.Choices[0]
269 content := make([]fantasy.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
270 text := choice.Message.Content
271 if text != "" {
272 content = append(content, fantasy.TextContent{
273 Text: text,
274 })
275 }
276 if o.extraContentFunc != nil {
277 extraContent := o.extraContentFunc(choice)
278 content = append(content, extraContent...)
279 }
280 for _, tc := range choice.Message.ToolCalls {
281 toolCallID := tc.ID
282 content = append(content, fantasy.ToolCallContent{
283 ProviderExecuted: false,
284 ToolCallID: toolCallID,
285 ToolName: tc.Function.Name,
286 Input: tc.Function.Arguments,
287 })
288 }
289 for _, annotation := range choice.Message.Annotations {
290 if annotation.Type == "url_citation" {
291 content = append(content, fantasy.SourceContent{
292 SourceType: fantasy.SourceTypeURL,
293 ID: uuid.NewString(),
294 URL: annotation.URLCitation.URL,
295 Title: annotation.URLCitation.Title,
296 })
297 }
298 }
299
300 usage, providerMetadata := o.usageFunc(*response)
301
302 mappedFinishReason := o.mapFinishReasonFunc(choice.FinishReason)
303 if len(choice.Message.ToolCalls) > 0 {
304 mappedFinishReason = fantasy.FinishReasonToolCalls
305 }
306 return &fantasy.Response{
307 Content: content,
308 Usage: usage,
309 FinishReason: mappedFinishReason,
310 ProviderMetadata: fantasy.ProviderMetadata{
311 Name: providerMetadata,
312 },
313 Warnings: warnings,
314 }, nil
315}
316
317// Stream implements fantasy.LanguageModel.
318func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
319 params, warnings, err := o.prepareParams(call)
320 if err != nil {
321 return nil, err
322 }
323
324 params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
325 IncludeUsage: openai.Bool(true),
326 }
327
328 stream := o.client.Chat.Completions.NewStreaming(ctx, *params, callUARequestOptions(call, o.noDefaultUserAgent)...)
329 isActiveText := false
330 toolCalls := make(map[int64]streamToolCall)
331
332 providerMetadata := fantasy.ProviderMetadata{
333 Name: &ProviderMetadata{},
334 }
335 acc := openai.ChatCompletionAccumulator{}
336 extraContext := make(map[string]any)
337 var usage fantasy.Usage
338 var finishReason string
339 return func(yield func(fantasy.StreamPart) bool) {
340 if len(warnings) > 0 {
341 if !yield(fantasy.StreamPart{
342 Type: fantasy.StreamPartTypeWarnings,
343 Warnings: warnings,
344 }) {
345 return
346 }
347 }
348 for stream.Next() {
349 chunk := stream.Current()
350 acc.AddChunk(chunk)
351 usage, providerMetadata = o.streamUsageFunc(chunk, extraContext, providerMetadata)
352 if len(chunk.Choices) == 0 {
353 continue
354 }
355 for _, choice := range chunk.Choices {
356 if choice.FinishReason != "" {
357 finishReason = choice.FinishReason
358 }
359 switch {
360 case choice.Delta.Content != "":
361 if !isActiveText {
362 isActiveText = true
363 if !yield(fantasy.StreamPart{
364 Type: fantasy.StreamPartTypeTextStart,
365 ID: "0",
366 }) {
367 return
368 }
369 }
370 if !yield(fantasy.StreamPart{
371 Type: fantasy.StreamPartTypeTextDelta,
372 ID: "0",
373 Delta: choice.Delta.Content,
374 }) {
375 return
376 }
377 case len(choice.Delta.ToolCalls) > 0:
378 if isActiveText {
379 isActiveText = false
380 if !yield(fantasy.StreamPart{
381 Type: fantasy.StreamPartTypeTextEnd,
382 ID: "0",
383 }) {
384 return
385 }
386 }
387
388 for _, toolCallDelta := range choice.Delta.ToolCalls {
389 if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
390 if existingToolCall.hasFinished {
391 continue
392 }
393 if toolCallDelta.Function.Arguments != "" {
394 existingToolCall.arguments += toolCallDelta.Function.Arguments
395 }
396 if !yield(fantasy.StreamPart{
397 Type: fantasy.StreamPartTypeToolInputDelta,
398 ID: existingToolCall.id,
399 Delta: toolCallDelta.Function.Arguments,
400 }) {
401 return
402 }
403 toolCalls[toolCallDelta.Index] = existingToolCall
404 if xjson.IsValid(existingToolCall.arguments) {
405 if !yield(fantasy.StreamPart{
406 Type: fantasy.StreamPartTypeToolInputEnd,
407 ID: existingToolCall.id,
408 }) {
409 return
410 }
411
412 if !yield(fantasy.StreamPart{
413 Type: fantasy.StreamPartTypeToolCall,
414 ID: existingToolCall.id,
415 ToolCallName: existingToolCall.name,
416 ToolCallInput: existingToolCall.arguments,
417 }) {
418 return
419 }
420 existingToolCall.hasFinished = true
421 toolCalls[toolCallDelta.Index] = existingToolCall
422 }
423 } else {
424 var err error
425 if toolCallDelta.Type != "function" {
426 err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function' type."}
427 }
428 if toolCallDelta.ID == "" {
429 err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'id' to be a string."}
430 }
431 if toolCallDelta.Function.Name == "" {
432 err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function.name' to be a string."}
433 }
434 if err != nil {
435 yield(fantasy.StreamPart{
436 Type: fantasy.StreamPartTypeError,
437 Error: toProviderErr(stream.Err()),
438 })
439 return
440 }
441
442 if !yield(fantasy.StreamPart{
443 Type: fantasy.StreamPartTypeToolInputStart,
444 ID: toolCallDelta.ID,
445 ToolCallName: toolCallDelta.Function.Name,
446 }) {
447 return
448 }
449 toolCalls[toolCallDelta.Index] = streamToolCall{
450 id: toolCallDelta.ID,
451 name: toolCallDelta.Function.Name,
452 arguments: toolCallDelta.Function.Arguments,
453 }
454
455 exTc := toolCalls[toolCallDelta.Index]
456 if exTc.arguments != "" {
457 if !yield(fantasy.StreamPart{
458 Type: fantasy.StreamPartTypeToolInputDelta,
459 ID: exTc.id,
460 Delta: exTc.arguments,
461 }) {
462 return
463 }
464 if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
465 if !yield(fantasy.StreamPart{
466 Type: fantasy.StreamPartTypeToolInputEnd,
467 ID: toolCallDelta.ID,
468 }) {
469 return
470 }
471
472 if !yield(fantasy.StreamPart{
473 Type: fantasy.StreamPartTypeToolCall,
474 ID: exTc.id,
475 ToolCallName: exTc.name,
476 ToolCallInput: exTc.arguments,
477 }) {
478 return
479 }
480 exTc.hasFinished = true
481 toolCalls[toolCallDelta.Index] = exTc
482 }
483 }
484 continue
485 }
486 }
487 }
488
489 if o.streamExtraFunc != nil {
490 updatedContext, shouldContinue := o.streamExtraFunc(chunk, yield, extraContext)
491 if !shouldContinue {
492 return
493 }
494 extraContext = updatedContext
495 }
496 }
497
498 for _, choice := range chunk.Choices {
499 if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
500 for _, annotation := range annotations {
501 if annotation.Type == "url_citation" {
502 if !yield(fantasy.StreamPart{
503 Type: fantasy.StreamPartTypeSource,
504 ID: uuid.NewString(),
505 SourceType: fantasy.SourceTypeURL,
506 URL: annotation.URLCitation.URL,
507 Title: annotation.URLCitation.Title,
508 }) {
509 return
510 }
511 }
512 }
513 }
514 }
515 }
516 err := stream.Err()
517 if err == nil || errors.Is(err, io.EOF) {
518 if isActiveText {
519 isActiveText = false
520 if !yield(fantasy.StreamPart{
521 Type: fantasy.StreamPartTypeTextEnd,
522 ID: "0",
523 }) {
524 return
525 }
526 }
527
528 if len(acc.Choices) > 0 {
529 choice := acc.Choices[0]
530 providerMetadata = o.streamProviderMetadataFunc(choice, providerMetadata)
531
532 for _, annotation := range choice.Message.Annotations {
533 if annotation.Type == "url_citation" {
534 if !yield(fantasy.StreamPart{
535 Type: fantasy.StreamPartTypeSource,
536 ID: acc.ID,
537 SourceType: fantasy.SourceTypeURL,
538 URL: annotation.URLCitation.URL,
539 Title: annotation.URLCitation.Title,
540 }) {
541 return
542 }
543 }
544 }
545 }
546 mappedFinishReason := o.mapFinishReasonFunc(finishReason)
547 if len(acc.Choices) > 0 {
548 choice := acc.Choices[0]
549 if len(choice.Message.ToolCalls) > 0 {
550 mappedFinishReason = fantasy.FinishReasonToolCalls
551 }
552 }
553 yield(fantasy.StreamPart{
554 Type: fantasy.StreamPartTypeFinish,
555 Usage: usage,
556 FinishReason: mappedFinishReason,
557 ProviderMetadata: providerMetadata,
558 })
559 return
560 } else { //nolint: revive
561 yield(fantasy.StreamPart{
562 Type: fantasy.StreamPartTypeError,
563 Error: toProviderErr(err),
564 })
565 return
566 }
567 }, nil
568}
569
// isReasoningModel reports whether modelID belongs to one of OpenAI's
// reasoning model families (o1/o3/o4, gpt-oss, gpt-5). prepareParams strips
// sampling settings and switches to max_completion_tokens for these models.
//
// NOTE(review): the original also tested Contains("gpt-5-chat"), which is
// dead code — any string containing "gpt-5-chat" already contains "gpt-5".
// If gpt-5-chat models were instead meant to be EXCLUDED (they are
// non-reasoning chat models), that needs a negated check — confirm intent.
func isReasoningModel(modelID string) bool {
	return strings.HasPrefix(modelID, "o1") || strings.Contains(modelID, "-o1") ||
		strings.HasPrefix(modelID, "o3") || strings.Contains(modelID, "-o3") ||
		strings.HasPrefix(modelID, "o4") || strings.Contains(modelID, "-o4") ||
		strings.HasPrefix(modelID, "oss") || strings.Contains(modelID, "-oss") ||
		strings.Contains(modelID, "gpt-5")
}
577
// isSearchPreviewModel reports whether modelID names one of the
// *-search-preview chat models, which reject the temperature setting.
func isSearchPreviewModel(modelID string) bool {
	const marker = "search-preview"
	return strings.Contains(modelID, marker)
}
581
// supportsFlexProcessing reports whether modelID is eligible for OpenAI's
// flex service tier (o3 family, o4-mini, gpt-5 family).
func supportsFlexProcessing(modelID string) bool {
	switch {
	case strings.HasPrefix(modelID, "o3"), strings.Contains(modelID, "-o3"):
		return true
	case strings.Contains(modelID, "o4-mini"), strings.Contains(modelID, "gpt-5"):
		return true
	default:
		return false
	}
}
586
// supportsPriorityProcessing reports whether modelID is eligible for OpenAI's
// priority service tier (gpt-4 family, gpt-5 family, o3 family, o4-mini).
//
// The original also tested Contains("gpt-5-mini") separately; that clause was
// dead code because Contains("gpt-5") already matches it, so it was removed.
// Behavior is unchanged.
func supportsPriorityProcessing(modelID string) bool {
	return strings.Contains(modelID, "gpt-4") || strings.Contains(modelID, "gpt-5") ||
		strings.HasPrefix(modelID, "o3") || strings.Contains(modelID, "-o3") ||
		strings.Contains(modelID, "o4-mini")
}
592
593func toOpenAiTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []fantasy.CallWarning) {
594 for _, tool := range tools {
595 if tool.GetType() == fantasy.ToolTypeFunction {
596 ft, ok := tool.(fantasy.FunctionTool)
597 if !ok {
598 continue
599 }
600 openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
601 OfFunction: &openai.ChatCompletionFunctionToolParam{
602 Function: shared.FunctionDefinitionParam{
603 Name: ft.Name,
604 Description: param.NewOpt(ft.Description),
605 Parameters: openai.FunctionParameters(ft.InputSchema),
606 Strict: param.NewOpt(false),
607 },
608 Type: "function",
609 },
610 })
611 continue
612 }
613
614 warnings = append(warnings, fantasy.CallWarning{
615 Type: fantasy.CallWarningTypeUnsupportedTool,
616 Tool: tool,
617 Message: "tool is not supported",
618 })
619 }
620 if toolChoice == nil {
621 return openAiTools, openAiToolChoice, warnings
622 }
623
624 switch *toolChoice {
625 case fantasy.ToolChoiceAuto:
626 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
627 OfAuto: param.NewOpt("auto"),
628 }
629 case fantasy.ToolChoiceNone:
630 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
631 OfAuto: param.NewOpt("none"),
632 }
633 case fantasy.ToolChoiceRequired:
634 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
635 OfAuto: param.NewOpt("required"),
636 }
637 default:
638 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
639 OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
640 Type: "function",
641 Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
642 Name: string(*toolChoice),
643 },
644 },
645 }
646 }
647 return openAiTools, openAiToolChoice, warnings
648}
649
650// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
651func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
652 var annotations []openai.ChatCompletionMessageAnnotation
653
654 // Parse the raw JSON to extract annotations
655 var deltaData map[string]any
656 if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
657 return annotations
658 }
659
660 // Check if annotations exist in the delta
661 if annotationsData, ok := deltaData["annotations"].([]any); ok {
662 for _, annotationData := range annotationsData {
663 if annotationMap, ok := annotationData.(map[string]any); ok {
664 if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
665 if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
666 url, urlOk := urlCitationData["url"].(string)
667 title, titleOk := urlCitationData["title"].(string)
668 if urlOk && titleOk {
669 annotation := openai.ChatCompletionMessageAnnotation{
670 Type: "url_citation",
671 URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
672 URL: url,
673 Title: title,
674 },
675 }
676 annotations = append(annotations, annotation)
677 }
678 }
679 }
680 }
681 }
682 }
683
684 return annotations
685}
686
687// GenerateObject implements fantasy.LanguageModel.
688func (o languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
689 switch o.objectMode {
690 case fantasy.ObjectModeText:
691 return object.GenerateWithText(ctx, o, call)
692 case fantasy.ObjectModeTool:
693 return object.GenerateWithTool(ctx, o, call)
694 default:
695 return o.generateObjectWithJSONMode(ctx, call)
696 }
697}
698
699// StreamObject implements fantasy.LanguageModel.
700func (o languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
701 switch o.objectMode {
702 case fantasy.ObjectModeTool:
703 return object.StreamWithTool(ctx, o, call)
704 case fantasy.ObjectModeText:
705 return object.StreamWithText(ctx, o, call)
706 default:
707 return o.streamObjectWithJSONMode(ctx, call)
708 }
709}
710
711func (o languageModel) generateObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
712 jsonSchemaMap := schema.ToMap(call.Schema)
713
714 addAdditionalPropertiesFalse(jsonSchemaMap)
715
716 schemaName := call.SchemaName
717 if schemaName == "" {
718 schemaName = "response"
719 }
720
721 fantasyCall := fantasy.Call{
722 Prompt: call.Prompt,
723 MaxOutputTokens: call.MaxOutputTokens,
724 Temperature: call.Temperature,
725 TopP: call.TopP,
726 PresencePenalty: call.PresencePenalty,
727 FrequencyPenalty: call.FrequencyPenalty,
728 ProviderOptions: call.ProviderOptions,
729 }
730
731 params, warnings, err := o.prepareParams(fantasyCall)
732 if err != nil {
733 return nil, err
734 }
735
736 params.ResponseFormat = openai.ChatCompletionNewParamsResponseFormatUnion{
737 OfJSONSchema: &shared.ResponseFormatJSONSchemaParam{
738 JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
739 Name: schemaName,
740 Description: param.NewOpt(call.SchemaDescription),
741 Schema: jsonSchemaMap,
742 Strict: param.NewOpt(true),
743 },
744 },
745 }
746
747 response, err := o.client.Chat.Completions.New(ctx, *params, objectCallUARequestOptions(call, o.noDefaultUserAgent)...)
748 if err != nil {
749 return nil, toProviderErr(err)
750 }
751 if len(response.Choices) == 0 {
752 usage, _ := o.usageFunc(*response)
753 return nil, &fantasy.NoObjectGeneratedError{
754 RawText: "",
755 ParseError: fmt.Errorf("no choices in response"),
756 Usage: usage,
757 FinishReason: fantasy.FinishReasonUnknown,
758 }
759 }
760
761 choice := response.Choices[0]
762 jsonText := choice.Message.Content
763
764 var obj any
765 if call.RepairText != nil {
766 obj, err = schema.ParseAndValidateWithRepair(ctx, jsonText, call.Schema, call.RepairText)
767 } else {
768 obj, err = schema.ParseAndValidate(jsonText, call.Schema)
769 }
770
771 usage, _ := o.usageFunc(*response)
772 finishReason := o.mapFinishReasonFunc(choice.FinishReason)
773
774 if err != nil {
775 if nogErr, ok := err.(*fantasy.NoObjectGeneratedError); ok {
776 nogErr.Usage = usage
777 nogErr.FinishReason = finishReason
778 }
779 return nil, err
780 }
781
782 return &fantasy.ObjectResponse{
783 Object: obj,
784 RawText: jsonText,
785 Usage: usage,
786 FinishReason: finishReason,
787 Warnings: warnings,
788 }, nil
789}
790
// streamObjectWithJSONMode streams a structured object using OpenAI's
// JSON-schema response format. As text deltas accumulate, the partial JSON is
// re-parsed on every chunk and each new schema-valid snapshot is yielded as
// an object part; a finish part is emitted only if at least one valid object
// was produced, otherwise an error part carries a NoObjectGeneratedError.
func (o languageModel) streamObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
	jsonSchemaMap := schema.ToMap(call.Schema)

	// Strict mode requires "additionalProperties": false on object schemas.
	addAdditionalPropertiesFalse(jsonSchemaMap)

	schemaName := call.SchemaName
	if schemaName == "" {
		schemaName = "response"
	}

	// Re-wrap the object call as a plain call so prepareParams can apply the
	// shared sampling/penalty handling.
	fantasyCall := fantasy.Call{
		Prompt:           call.Prompt,
		MaxOutputTokens:  call.MaxOutputTokens,
		Temperature:      call.Temperature,
		TopP:             call.TopP,
		PresencePenalty:  call.PresencePenalty,
		FrequencyPenalty: call.FrequencyPenalty,
		ProviderOptions:  call.ProviderOptions,
	}

	params, warnings, err := o.prepareParams(fantasyCall)
	if err != nil {
		return nil, err
	}

	params.ResponseFormat = openai.ChatCompletionNewParamsResponseFormatUnion{
		OfJSONSchema: &shared.ResponseFormatJSONSchemaParam{
			JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
				Name:        schemaName,
				Description: param.NewOpt(call.SchemaDescription),
				Schema:      jsonSchemaMap,
				Strict:      param.NewOpt(true),
			},
		},
	}

	// Ask the provider to attach token usage to the final chunk.
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params, objectCallUARequestOptions(call, o.noDefaultUserAgent)...)

	return func(yield func(fantasy.ObjectStreamPart) bool) {
		// NOTE(review): warnings ride on a part typed
		// ObjectStreamPartTypeObject; if the part enum has a dedicated
		// warnings variant, this may be unintended — confirm.
		if len(warnings) > 0 {
			if !yield(fantasy.ObjectStreamPart{
				Type:     fantasy.ObjectStreamPartTypeObject,
				Warnings: warnings,
			}) {
				return
			}
		}

		var accumulated string        // full text received so far
		var lastParsedObject any      // last schema-valid snapshot yielded
		var usage fantasy.Usage
		var finishReason fantasy.FinishReason
		var providerMetadata fantasy.ProviderMetadata
		var streamErr error

		for stream.Next() {
			chunk := stream.Current()

			// Update usage
			usage, providerMetadata = o.streamUsageFunc(chunk, make(map[string]any), providerMetadata)

			if len(chunk.Choices) == 0 {
				continue
			}

			choice := chunk.Choices[0]
			if choice.FinishReason != "" {
				finishReason = o.mapFinishReasonFunc(choice.FinishReason)
			}

			if choice.Delta.Content != "" {
				accumulated += choice.Delta.Content

				// Try to parse the (possibly incomplete) JSON on each delta.
				obj, state, parseErr := schema.ParsePartialJSON(accumulated)

				if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
					if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
						// Only yield when the snapshot actually changed.
						if !reflect.DeepEqual(obj, lastParsedObject) {
							if !yield(fantasy.ObjectStreamPart{
								Type:   fantasy.ObjectStreamPartTypeObject,
								Object: obj,
							}) {
								return
							}
							lastParsedObject = obj
						}
					}
				}

				// On a hard parse failure, give the caller's repair hook a
				// chance to fix the accumulated text.
				if state == schema.ParseStateFailed && call.RepairText != nil {
					repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
					if repairErr == nil {
						obj2, state2, _ := schema.ParsePartialJSON(repairedText)
						if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
							schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
							if !reflect.DeepEqual(obj2, lastParsedObject) {
								if !yield(fantasy.ObjectStreamPart{
									Type:   fantasy.ObjectStreamPartTypeObject,
									Object: obj2,
								}) {
									return
								}
								lastParsedObject = obj2
							}
						}
					}
				}
			}
		}

		err := stream.Err()
		if err != nil && !errors.Is(err, io.EOF) {
			streamErr = toProviderErr(err)
			yield(fantasy.ObjectStreamPart{
				Type:  fantasy.ObjectStreamPartTypeError,
				Error: streamErr,
			})
			return
		}

		if lastParsedObject != nil {
			yield(fantasy.ObjectStreamPart{
				Type:             fantasy.ObjectStreamPartTypeFinish,
				Usage:            usage,
				FinishReason:     finishReason,
				ProviderMetadata: providerMetadata,
			})
		} else {
			// Stream ended without ever producing a schema-valid object.
			yield(fantasy.ObjectStreamPart{
				Type: fantasy.ObjectStreamPartTypeError,
				Error: &fantasy.NoObjectGeneratedError{
					RawText:      accumulated,
					ParseError:   fmt.Errorf("no valid object generated in stream"),
					Usage:        usage,
					FinishReason: finishReason,
				},
			})
		}
	}, nil
}
934}
935
// addAdditionalPropertiesFalse recursively adds "additionalProperties": false
// to every object schema that does not already declare it, as required by
// OpenAI's strict mode for structured outputs. It descends into "properties"
// of object schemas and into "items" of any schema.
// NOTE: combinators ($defs, anyOf, oneOf, allOf) are not traversed — same as
// the original behavior.
func addAdditionalPropertiesFalse(schema map[string]any) {
	if schema["type"] == "object" {
		// Respect an explicitly set additionalProperties value.
		if _, exists := schema["additionalProperties"]; !exists {
			schema["additionalProperties"] = false
		}

		// Recurse into each property's sub-schema.
		props, ok := schema["properties"].(map[string]any)
		if ok {
			for _, raw := range props {
				sub, ok := raw.(map[string]any)
				if ok {
					addAdditionalPropertiesFalse(sub)
				}
			}
		}
	}

	// Arrays carry their element schema under "items".
	if items, ok := schema["items"].(map[string]any); ok {
		addAdditionalPropertiesFalse(items)
	}
}