1package openai
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "reflect"
10 "strings"
11
12 "charm.land/fantasy"
13 "charm.land/fantasy/object"
14 "charm.land/fantasy/schema"
15 xjson "github.com/charmbracelet/x/json"
16 "github.com/google/uuid"
17 "github.com/openai/openai-go/v2"
18 "github.com/openai/openai-go/v2/packages/param"
19 "github.com/openai/openai-go/v2/shared"
20)
21
// languageModel implements fantasy.LanguageModel on top of the OpenAI
// chat-completions API. Provider-specific behavior (prompt conversion,
// usage extraction, finish-reason mapping, extra stream parts) is injected
// through the function-valued fields so OpenAI-compatible providers can
// customize it via LanguageModelOption.
type languageModel struct {
	provider                   string                                  // provider name reported by Provider()
	modelID                    string                                  // model identifier sent to the API and reported by Model()
	client                     openai.Client                           // underlying OpenAI SDK client
	objectMode                 fantasy.ObjectMode                      // strategy used by GenerateObject/StreamObject
	prepareCallFunc            LanguageModelPrepareCallFunc            // hook to finalize request params before sending
	mapFinishReasonFunc        LanguageModelMapFinishReasonFunc        // maps API finish reasons to fantasy finish reasons
	extraContentFunc           LanguageModelExtraContentFunc           // optional: extracts extra content from a choice in Generate
	usageFunc                  LanguageModelUsageFunc                  // extracts usage and provider metadata from a response
	streamUsageFunc            LanguageModelStreamUsageFunc            // accumulates usage/metadata from stream chunks
	streamExtraFunc            LanguageModelStreamExtraFunc            // optional: emits extra stream parts per chunk
	streamProviderMetadataFunc LanguageModelStreamProviderMetadataFunc // finalizes provider metadata after streaming
	toPromptFunc               LanguageModelToPromptFunc               // converts a fantasy prompt into chat messages
}
36
// LanguageModelOption is a function that configures a languageModel
// constructed by newLanguageModel.
type LanguageModelOption = func(*languageModel)
39
40// WithLanguageModelPrepareCallFunc sets the prepare call function for the language model.
41func WithLanguageModelPrepareCallFunc(fn LanguageModelPrepareCallFunc) LanguageModelOption {
42 return func(l *languageModel) {
43 l.prepareCallFunc = fn
44 }
45}
46
47// WithLanguageModelMapFinishReasonFunc sets the map finish reason function for the language model.
48func WithLanguageModelMapFinishReasonFunc(fn LanguageModelMapFinishReasonFunc) LanguageModelOption {
49 return func(l *languageModel) {
50 l.mapFinishReasonFunc = fn
51 }
52}
53
54// WithLanguageModelExtraContentFunc sets the extra content function for the language model.
55func WithLanguageModelExtraContentFunc(fn LanguageModelExtraContentFunc) LanguageModelOption {
56 return func(l *languageModel) {
57 l.extraContentFunc = fn
58 }
59}
60
61// WithLanguageModelStreamExtraFunc sets the stream extra function for the language model.
62func WithLanguageModelStreamExtraFunc(fn LanguageModelStreamExtraFunc) LanguageModelOption {
63 return func(l *languageModel) {
64 l.streamExtraFunc = fn
65 }
66}
67
68// WithLanguageModelUsageFunc sets the usage function for the language model.
69func WithLanguageModelUsageFunc(fn LanguageModelUsageFunc) LanguageModelOption {
70 return func(l *languageModel) {
71 l.usageFunc = fn
72 }
73}
74
75// WithLanguageModelStreamUsageFunc sets the stream usage function for the language model.
76func WithLanguageModelStreamUsageFunc(fn LanguageModelStreamUsageFunc) LanguageModelOption {
77 return func(l *languageModel) {
78 l.streamUsageFunc = fn
79 }
80}
81
82// WithLanguageModelToPromptFunc sets the to prompt function for the language model.
83func WithLanguageModelToPromptFunc(fn LanguageModelToPromptFunc) LanguageModelOption {
84 return func(l *languageModel) {
85 l.toPromptFunc = fn
86 }
87}
88
89// WithLanguageModelObjectMode sets the object generation mode.
90func WithLanguageModelObjectMode(om fantasy.ObjectMode) LanguageModelOption {
91 return func(l *languageModel) {
92 // not supported
93 if om == fantasy.ObjectModeJSON {
94 om = fantasy.ObjectModeAuto
95 }
96 l.objectMode = om
97 }
98}
99
100func newLanguageModel(modelID string, provider string, client openai.Client, opts ...LanguageModelOption) languageModel {
101 model := languageModel{
102 modelID: modelID,
103 provider: provider,
104 client: client,
105 objectMode: fantasy.ObjectModeAuto,
106 prepareCallFunc: DefaultPrepareCallFunc,
107 mapFinishReasonFunc: DefaultMapFinishReasonFunc,
108 usageFunc: DefaultUsageFunc,
109 streamUsageFunc: DefaultStreamUsageFunc,
110 streamProviderMetadataFunc: DefaultStreamProviderMetadataFunc,
111 toPromptFunc: DefaultToPrompt,
112 }
113
114 for _, o := range opts {
115 o(&model)
116 }
117 return model
118}
119
// streamToolCall accumulates one tool call as its pieces arrive across
// streaming deltas, keyed by the delta's tool-call index.
type streamToolCall struct {
	id          string
	name        string
	arguments   string // concatenated JSON argument fragments
	hasFinished bool   // true once a complete tool-call part has been yielded
}
126
// Model implements fantasy.LanguageModel. It returns the model identifier
// this instance was constructed with.
func (o languageModel) Model() string {
	return o.modelID
}
131
// Provider implements fantasy.LanguageModel. It returns the provider name
// this instance was constructed with.
func (o languageModel) Provider() string {
	return o.provider
}
136
// prepareParams translates a fantasy.Call into OpenAI chat-completion
// parameters. It strips settings the target model does not support
// (collecting CallWarnings for each), runs the provider-specific
// prepareCallFunc hook, and finally attaches messages and tools.
func (o languageModel) prepareParams(call fantasy.Call) (*openai.ChatCompletionNewParams, []fantasy.CallWarning, error) {
	params := &openai.ChatCompletionNewParams{}
	messages, warnings := o.toPromptFunc(call.Prompt, o.provider, o.modelID)
	// top_k has no equivalent in the chat-completions API.
	if call.TopK != nil {
		warnings = append(warnings, fantasy.CallWarning{
			Type:    fantasy.CallWarningTypeUnsupportedSetting,
			Setting: "top_k",
		})
	}

	if call.MaxOutputTokens != nil {
		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
	}
	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}
	if call.FrequencyPenalty != nil {
		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
	}
	if call.PresencePenalty != nil {
		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
	}

	if isReasoningModel(o.modelID) {
		// remove unsupported settings for reasoning models
		// see https://platform.openai.com/docs/guides/reasoning#limitations
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{} // reset to unset
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for reasoning models",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "TopP",
				Details: "TopP is not supported for reasoning models",
			})
		}
		if call.FrequencyPenalty != nil {
			params.FrequencyPenalty = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "FrequencyPenalty",
				Details: "FrequencyPenalty is not supported for reasoning models",
			})
		}
		if call.PresencePenalty != nil {
			params.PresencePenalty = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "PresencePenalty",
				Details: "PresencePenalty is not supported for reasoning models",
			})
		}

		// reasoning models use max_completion_tokens instead of max_tokens
		if call.MaxOutputTokens != nil {
			// Don't clobber a value the caller (or a hook) may have set.
			if !params.MaxCompletionTokens.Valid() {
				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
			}
			params.MaxTokens = param.Opt[int64]{}
		}
	}

	// Handle search preview models
	if isSearchPreviewModel(o.modelID) {
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for the search preview models and has been removed.",
			})
		}
	}

	// Provider-specific adjustments (may add its own warnings or fail).
	optionsWarnings, err := o.prepareCallFunc(o, params, call)
	if err != nil {
		return nil, nil, err
	}

	if len(optionsWarnings) > 0 {
		warnings = append(warnings, optionsWarnings...)
	}

	params.Messages = messages
	params.Model = o.modelID

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}
	return params, warnings, nil
}
242
// Generate implements fantasy.LanguageModel. It performs a single
// (non-streaming) chat completion and converts the first choice into a
// fantasy.Response containing text, tool-call, and source content.
func (o languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, toProviderErr(err)
	}

	if len(response.Choices) == 0 {
		return nil, &fantasy.Error{Title: "no response", Message: "no response generated"}
	}
	choice := response.Choices[0]
	content := make([]fantasy.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
	text := choice.Message.Content
	if text != "" {
		content = append(content, fantasy.TextContent{
			Text: text,
		})
	}
	// Provider-specific extra content (if configured) is appended after the text.
	if o.extraContentFunc != nil {
		extraContent := o.extraContentFunc(choice)
		content = append(content, extraContent...)
	}
	for _, tc := range choice.Message.ToolCalls {
		toolCallID := tc.ID
		content = append(content, fantasy.ToolCallContent{
			ProviderExecuted: false,
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// URL citations on the message become source content entries.
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, fantasy.SourceContent{
				SourceType: fantasy.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	usage, providerMetadata := o.usageFunc(*response)

	// Force the tool-calls finish reason whenever tool calls are present,
	// regardless of what the API reported.
	mappedFinishReason := o.mapFinishReasonFunc(choice.FinishReason)
	if len(choice.Message.ToolCalls) > 0 {
		mappedFinishReason = fantasy.FinishReasonToolCalls
	}
	return &fantasy.Response{
		Content:      content,
		Usage:        usage,
		FinishReason: mappedFinishReason,
		ProviderMetadata: fantasy.ProviderMetadata{
			Name: providerMetadata,
		},
		Warnings: warnings,
	}, nil
}
305
306// Stream implements fantasy.LanguageModel.
307func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
308 params, warnings, err := o.prepareParams(call)
309 if err != nil {
310 return nil, err
311 }
312
313 params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
314 IncludeUsage: openai.Bool(true),
315 }
316
317 stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
318 isActiveText := false
319 toolCalls := make(map[int64]streamToolCall)
320
321 providerMetadata := fantasy.ProviderMetadata{
322 Name: &ProviderMetadata{},
323 }
324 acc := openai.ChatCompletionAccumulator{}
325 extraContext := make(map[string]any)
326 var usage fantasy.Usage
327 var finishReason string
328 return func(yield func(fantasy.StreamPart) bool) {
329 if len(warnings) > 0 {
330 if !yield(fantasy.StreamPart{
331 Type: fantasy.StreamPartTypeWarnings,
332 Warnings: warnings,
333 }) {
334 return
335 }
336 }
337 for stream.Next() {
338 chunk := stream.Current()
339 acc.AddChunk(chunk)
340 usage, providerMetadata = o.streamUsageFunc(chunk, extraContext, providerMetadata)
341 if len(chunk.Choices) == 0 {
342 continue
343 }
344 for _, choice := range chunk.Choices {
345 if choice.FinishReason != "" {
346 finishReason = choice.FinishReason
347 }
348 switch {
349 case choice.Delta.Content != "":
350 if !isActiveText {
351 isActiveText = true
352 if !yield(fantasy.StreamPart{
353 Type: fantasy.StreamPartTypeTextStart,
354 ID: "0",
355 }) {
356 return
357 }
358 }
359 if !yield(fantasy.StreamPart{
360 Type: fantasy.StreamPartTypeTextDelta,
361 ID: "0",
362 Delta: choice.Delta.Content,
363 }) {
364 return
365 }
366 case len(choice.Delta.ToolCalls) > 0:
367 if isActiveText {
368 isActiveText = false
369 if !yield(fantasy.StreamPart{
370 Type: fantasy.StreamPartTypeTextEnd,
371 ID: "0",
372 }) {
373 return
374 }
375 }
376
377 for _, toolCallDelta := range choice.Delta.ToolCalls {
378 if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
379 if existingToolCall.hasFinished {
380 continue
381 }
382 if toolCallDelta.Function.Arguments != "" {
383 existingToolCall.arguments += toolCallDelta.Function.Arguments
384 }
385 if !yield(fantasy.StreamPart{
386 Type: fantasy.StreamPartTypeToolInputDelta,
387 ID: existingToolCall.id,
388 Delta: toolCallDelta.Function.Arguments,
389 }) {
390 return
391 }
392 toolCalls[toolCallDelta.Index] = existingToolCall
393 if xjson.IsValid(existingToolCall.arguments) {
394 if !yield(fantasy.StreamPart{
395 Type: fantasy.StreamPartTypeToolInputEnd,
396 ID: existingToolCall.id,
397 }) {
398 return
399 }
400
401 if !yield(fantasy.StreamPart{
402 Type: fantasy.StreamPartTypeToolCall,
403 ID: existingToolCall.id,
404 ToolCallName: existingToolCall.name,
405 ToolCallInput: existingToolCall.arguments,
406 }) {
407 return
408 }
409 existingToolCall.hasFinished = true
410 toolCalls[toolCallDelta.Index] = existingToolCall
411 }
412 } else {
413 var err error
414 if toolCallDelta.Type != "function" {
415 err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function' type."}
416 }
417 if toolCallDelta.ID == "" {
418 err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'id' to be a string."}
419 }
420 if toolCallDelta.Function.Name == "" {
421 err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function.name' to be a string."}
422 }
423 if err != nil {
424 yield(fantasy.StreamPart{
425 Type: fantasy.StreamPartTypeError,
426 Error: toProviderErr(stream.Err()),
427 })
428 return
429 }
430
431 if !yield(fantasy.StreamPart{
432 Type: fantasy.StreamPartTypeToolInputStart,
433 ID: toolCallDelta.ID,
434 ToolCallName: toolCallDelta.Function.Name,
435 }) {
436 return
437 }
438 toolCalls[toolCallDelta.Index] = streamToolCall{
439 id: toolCallDelta.ID,
440 name: toolCallDelta.Function.Name,
441 arguments: toolCallDelta.Function.Arguments,
442 }
443
444 exTc := toolCalls[toolCallDelta.Index]
445 if exTc.arguments != "" {
446 if !yield(fantasy.StreamPart{
447 Type: fantasy.StreamPartTypeToolInputDelta,
448 ID: exTc.id,
449 Delta: exTc.arguments,
450 }) {
451 return
452 }
453 if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
454 if !yield(fantasy.StreamPart{
455 Type: fantasy.StreamPartTypeToolInputEnd,
456 ID: toolCallDelta.ID,
457 }) {
458 return
459 }
460
461 if !yield(fantasy.StreamPart{
462 Type: fantasy.StreamPartTypeToolCall,
463 ID: exTc.id,
464 ToolCallName: exTc.name,
465 ToolCallInput: exTc.arguments,
466 }) {
467 return
468 }
469 exTc.hasFinished = true
470 toolCalls[toolCallDelta.Index] = exTc
471 }
472 }
473 continue
474 }
475 }
476 }
477
478 if o.streamExtraFunc != nil {
479 updatedContext, shouldContinue := o.streamExtraFunc(chunk, yield, extraContext)
480 if !shouldContinue {
481 return
482 }
483 extraContext = updatedContext
484 }
485 }
486
487 for _, choice := range chunk.Choices {
488 if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
489 for _, annotation := range annotations {
490 if annotation.Type == "url_citation" {
491 if !yield(fantasy.StreamPart{
492 Type: fantasy.StreamPartTypeSource,
493 ID: uuid.NewString(),
494 SourceType: fantasy.SourceTypeURL,
495 URL: annotation.URLCitation.URL,
496 Title: annotation.URLCitation.Title,
497 }) {
498 return
499 }
500 }
501 }
502 }
503 }
504 }
505 err := stream.Err()
506 if err == nil || errors.Is(err, io.EOF) {
507 if isActiveText {
508 isActiveText = false
509 if !yield(fantasy.StreamPart{
510 Type: fantasy.StreamPartTypeTextEnd,
511 ID: "0",
512 }) {
513 return
514 }
515 }
516
517 if len(acc.Choices) > 0 {
518 choice := acc.Choices[0]
519 providerMetadata = o.streamProviderMetadataFunc(choice, providerMetadata)
520
521 for _, annotation := range choice.Message.Annotations {
522 if annotation.Type == "url_citation" {
523 if !yield(fantasy.StreamPart{
524 Type: fantasy.StreamPartTypeSource,
525 ID: acc.ID,
526 SourceType: fantasy.SourceTypeURL,
527 URL: annotation.URLCitation.URL,
528 Title: annotation.URLCitation.Title,
529 }) {
530 return
531 }
532 }
533 }
534 }
535 mappedFinishReason := o.mapFinishReasonFunc(finishReason)
536 if len(acc.Choices) > 0 {
537 choice := acc.Choices[0]
538 if len(choice.Message.ToolCalls) > 0 {
539 mappedFinishReason = fantasy.FinishReasonToolCalls
540 }
541 }
542 yield(fantasy.StreamPart{
543 Type: fantasy.StreamPartTypeFinish,
544 Usage: usage,
545 FinishReason: mappedFinishReason,
546 ProviderMetadata: providerMetadata,
547 })
548 return
549 } else { //nolint: revive
550 yield(fantasy.StreamPart{
551 Type: fantasy.StreamPartTypeError,
552 Error: toProviderErr(err),
553 })
554 return
555 }
556 }, nil
557}
558
// isReasoningModel reports whether the model ID names an OpenAI reasoning
// model (o1/o3/o4/oss families and the gpt-5 family). The redundant
// gpt-5-chat check was removed: any ID containing "gpt-5-chat" already
// contains "gpt-5", so the extra clause was dead code.
func isReasoningModel(modelID string) bool {
	return strings.HasPrefix(modelID, "o1") || strings.Contains(modelID, "-o1") ||
		strings.HasPrefix(modelID, "o3") || strings.Contains(modelID, "-o3") ||
		strings.HasPrefix(modelID, "o4") || strings.Contains(modelID, "-o4") ||
		strings.HasPrefix(modelID, "oss") || strings.Contains(modelID, "-oss") ||
		strings.Contains(modelID, "gpt-5")
}
566
// isSearchPreviewModel reports whether the model ID belongs to the
// search-preview model family.
func isSearchPreviewModel(modelID string) bool {
	const marker = "search-preview"
	return strings.Contains(modelID, marker)
}
570
// supportsFlexProcessing reports whether the model supports OpenAI's flex
// service tier (o3 family, o4-mini, and the gpt-5 family).
func supportsFlexProcessing(modelID string) bool {
	if strings.HasPrefix(modelID, "o3") || strings.Contains(modelID, "-o3") {
		return true
	}
	return strings.Contains(modelID, "o4-mini") || strings.Contains(modelID, "gpt-5")
}
575
// supportsPriorityProcessing reports whether the model supports OpenAI's
// priority service tier (gpt-4/gpt-5 families, o3 family, o4-mini). The
// redundant gpt-5-mini check was removed: any ID containing "gpt-5-mini"
// already contains "gpt-5", so the extra clause was dead code.
func supportsPriorityProcessing(modelID string) bool {
	return strings.Contains(modelID, "gpt-4") || strings.Contains(modelID, "gpt-5") ||
		strings.HasPrefix(modelID, "o3") || strings.Contains(modelID, "-o3") ||
		strings.Contains(modelID, "o4-mini")
}
581
582func toOpenAiTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []fantasy.CallWarning) {
583 for _, tool := range tools {
584 if tool.GetType() == fantasy.ToolTypeFunction {
585 ft, ok := tool.(fantasy.FunctionTool)
586 if !ok {
587 continue
588 }
589 openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
590 OfFunction: &openai.ChatCompletionFunctionToolParam{
591 Function: shared.FunctionDefinitionParam{
592 Name: ft.Name,
593 Description: param.NewOpt(ft.Description),
594 Parameters: openai.FunctionParameters(ft.InputSchema),
595 Strict: param.NewOpt(false),
596 },
597 Type: "function",
598 },
599 })
600 continue
601 }
602
603 warnings = append(warnings, fantasy.CallWarning{
604 Type: fantasy.CallWarningTypeUnsupportedTool,
605 Tool: tool,
606 Message: "tool is not supported",
607 })
608 }
609 if toolChoice == nil {
610 return openAiTools, openAiToolChoice, warnings
611 }
612
613 switch *toolChoice {
614 case fantasy.ToolChoiceAuto:
615 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
616 OfAuto: param.NewOpt("auto"),
617 }
618 case fantasy.ToolChoiceNone:
619 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
620 OfAuto: param.NewOpt("none"),
621 }
622 default:
623 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
624 OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
625 Type: "function",
626 Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
627 Name: string(*toolChoice),
628 },
629 },
630 }
631 }
632 return openAiTools, openAiToolChoice, warnings
633}
634
635// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
636func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
637 var annotations []openai.ChatCompletionMessageAnnotation
638
639 // Parse the raw JSON to extract annotations
640 var deltaData map[string]any
641 if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
642 return annotations
643 }
644
645 // Check if annotations exist in the delta
646 if annotationsData, ok := deltaData["annotations"].([]any); ok {
647 for _, annotationData := range annotationsData {
648 if annotationMap, ok := annotationData.(map[string]any); ok {
649 if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
650 if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
651 annotation := openai.ChatCompletionMessageAnnotation{
652 Type: "url_citation",
653 URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
654 URL: urlCitationData["url"].(string),
655 Title: urlCitationData["title"].(string),
656 },
657 }
658 annotations = append(annotations, annotation)
659 }
660 }
661 }
662 }
663 }
664
665 return annotations
666}
667
668// GenerateObject implements fantasy.LanguageModel.
669func (o languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
670 switch o.objectMode {
671 case fantasy.ObjectModeText:
672 return object.GenerateWithText(ctx, o, call)
673 case fantasy.ObjectModeTool:
674 return object.GenerateWithTool(ctx, o, call)
675 default:
676 return o.generateObjectWithJSONMode(ctx, call)
677 }
678}
679
680// StreamObject implements fantasy.LanguageModel.
681func (o languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
682 switch o.objectMode {
683 case fantasy.ObjectModeTool:
684 return object.StreamWithTool(ctx, o, call)
685 case fantasy.ObjectModeText:
686 return object.StreamWithText(ctx, o, call)
687 default:
688 return o.streamObjectWithJSONMode(ctx, call)
689 }
690}
691
692func (o languageModel) generateObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
693 jsonSchemaMap := schema.ToMap(call.Schema)
694
695 addAdditionalPropertiesFalse(jsonSchemaMap)
696
697 schemaName := call.SchemaName
698 if schemaName == "" {
699 schemaName = "response"
700 }
701
702 fantasyCall := fantasy.Call{
703 Prompt: call.Prompt,
704 MaxOutputTokens: call.MaxOutputTokens,
705 Temperature: call.Temperature,
706 TopP: call.TopP,
707 PresencePenalty: call.PresencePenalty,
708 FrequencyPenalty: call.FrequencyPenalty,
709 ProviderOptions: call.ProviderOptions,
710 }
711
712 params, warnings, err := o.prepareParams(fantasyCall)
713 if err != nil {
714 return nil, err
715 }
716
717 params.ResponseFormat = openai.ChatCompletionNewParamsResponseFormatUnion{
718 OfJSONSchema: &shared.ResponseFormatJSONSchemaParam{
719 JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
720 Name: schemaName,
721 Description: param.NewOpt(call.SchemaDescription),
722 Schema: jsonSchemaMap,
723 Strict: param.NewOpt(true),
724 },
725 },
726 }
727
728 response, err := o.client.Chat.Completions.New(ctx, *params)
729 if err != nil {
730 return nil, toProviderErr(err)
731 }
732
733 if len(response.Choices) == 0 {
734 usage, _ := o.usageFunc(*response)
735 return nil, &fantasy.NoObjectGeneratedError{
736 RawText: "",
737 ParseError: fmt.Errorf("no choices in response"),
738 Usage: usage,
739 FinishReason: fantasy.FinishReasonUnknown,
740 }
741 }
742
743 choice := response.Choices[0]
744 jsonText := choice.Message.Content
745
746 var obj any
747 if call.RepairText != nil {
748 obj, err = schema.ParseAndValidateWithRepair(ctx, jsonText, call.Schema, call.RepairText)
749 } else {
750 obj, err = schema.ParseAndValidate(jsonText, call.Schema)
751 }
752
753 usage, _ := o.usageFunc(*response)
754 finishReason := o.mapFinishReasonFunc(choice.FinishReason)
755
756 if err != nil {
757 if nogErr, ok := err.(*fantasy.NoObjectGeneratedError); ok {
758 nogErr.Usage = usage
759 nogErr.FinishReason = finishReason
760 }
761 return nil, err
762 }
763
764 return &fantasy.ObjectResponse{
765 Object: obj,
766 RawText: jsonText,
767 Usage: usage,
768 FinishReason: finishReason,
769 Warnings: warnings,
770 }, nil
771}
772
// streamObjectWithJSONMode streams a structured object using OpenAI's strict
// JSON-schema response format. As content deltas arrive, the accumulated
// text is partially parsed; each time a new valid object is produced it is
// yielded, followed by a finish part (or an error part if no valid object
// was ever produced).
func (o languageModel) streamObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
	jsonSchemaMap := schema.ToMap(call.Schema)

	// OpenAI strict mode requires additionalProperties: false on objects.
	addAdditionalPropertiesFalse(jsonSchemaMap)

	schemaName := call.SchemaName
	if schemaName == "" {
		schemaName = "response"
	}

	// Re-pack the object call as a plain call so prepareParams applies the
	// usual sampling settings and provider options.
	fantasyCall := fantasy.Call{
		Prompt:           call.Prompt,
		MaxOutputTokens:  call.MaxOutputTokens,
		Temperature:      call.Temperature,
		TopP:             call.TopP,
		PresencePenalty:  call.PresencePenalty,
		FrequencyPenalty: call.FrequencyPenalty,
		ProviderOptions:  call.ProviderOptions,
	}

	params, warnings, err := o.prepareParams(fantasyCall)
	if err != nil {
		return nil, err
	}

	params.ResponseFormat = openai.ChatCompletionNewParamsResponseFormatUnion{
		OfJSONSchema: &shared.ResponseFormatJSONSchemaParam{
			JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
				Name:        schemaName,
				Description: param.NewOpt(call.SchemaDescription),
				Schema:      jsonSchemaMap,
				Strict:      param.NewOpt(true),
			},
		},
	}

	// Request the trailing usage-only chunk so token counts are available.
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)

	return func(yield func(fantasy.ObjectStreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(fantasy.ObjectStreamPart{
				Type:     fantasy.ObjectStreamPartTypeObject,
				Warnings: warnings,
			}) {
				return
			}
		}

		var accumulated string        // full content text received so far
		var lastParsedObject any      // last object yielded, to dedupe repeats
		var usage fantasy.Usage
		var finishReason fantasy.FinishReason
		var providerMetadata fantasy.ProviderMetadata
		var streamErr error

		for stream.Next() {
			chunk := stream.Current()

			// Update usage
			usage, providerMetadata = o.streamUsageFunc(chunk, make(map[string]any), providerMetadata)

			if len(chunk.Choices) == 0 {
				continue
			}

			choice := chunk.Choices[0]
			if choice.FinishReason != "" {
				finishReason = o.mapFinishReasonFunc(choice.FinishReason)
			}

			if choice.Delta.Content != "" {
				accumulated += choice.Delta.Content

				// Try to parse the (possibly incomplete) accumulated JSON.
				obj, state, parseErr := schema.ParsePartialJSON(accumulated)

				if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
					if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
						// Only yield when the parsed object actually changed.
						if !reflect.DeepEqual(obj, lastParsedObject) {
							if !yield(fantasy.ObjectStreamPart{
								Type:   fantasy.ObjectStreamPartTypeObject,
								Object: obj,
							}) {
								return
							}
							lastParsedObject = obj
						}
					}
				}

				// On a hard parse failure, give the caller's repair hook a
				// chance to fix the text and retry.
				if state == schema.ParseStateFailed && call.RepairText != nil {
					repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
					if repairErr == nil {
						obj2, state2, _ := schema.ParsePartialJSON(repairedText)
						if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
							schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
							if !reflect.DeepEqual(obj2, lastParsedObject) {
								if !yield(fantasy.ObjectStreamPart{
									Type:   fantasy.ObjectStreamPartTypeObject,
									Object: obj2,
								}) {
									return
								}
								lastParsedObject = obj2
							}
						}
					}
				}
			}
		}

		err := stream.Err()
		if err != nil && !errors.Is(err, io.EOF) {
			streamErr = toProviderErr(err)
			yield(fantasy.ObjectStreamPart{
				Type:  fantasy.ObjectStreamPartTypeError,
				Error: streamErr,
			})
			return
		}

		// A finish part is only emitted if at least one valid object was
		// produced; otherwise report NoObjectGeneratedError.
		if lastParsedObject != nil {
			yield(fantasy.ObjectStreamPart{
				Type:             fantasy.ObjectStreamPartTypeFinish,
				Usage:            usage,
				FinishReason:     finishReason,
				ProviderMetadata: providerMetadata,
			})
		} else {
			yield(fantasy.ObjectStreamPart{
				Type: fantasy.ObjectStreamPartTypeError,
				Error: &fantasy.NoObjectGeneratedError{
					RawText:      accumulated,
					ParseError:   fmt.Errorf("no valid object generated in stream"),
					Usage:        usage,
					FinishReason: finishReason,
				},
			})
		}
	}, nil
}
917
// addAdditionalPropertiesFalse recursively adds "additionalProperties": false
// to all object schemas, as required by OpenAI's strict mode for structured
// outputs (every object in the schema must disallow extra properties). It
// now also descends into $defs/definitions, the allOf/anyOf/oneOf
// combinators, and tuple-form "items", which the previous version missed,
// and tolerates a nil map.
func addAdditionalPropertiesFalse(schema map[string]any) {
	if schema == nil {
		return
	}
	if schema["type"] == "object" {
		if _, hasAdditional := schema["additionalProperties"]; !hasAdditional {
			schema["additionalProperties"] = false
		}

		// Recursively process nested properties.
		if properties, ok := schema["properties"].(map[string]any); ok {
			for _, propValue := range properties {
				if propSchema, ok := propValue.(map[string]any); ok {
					addAdditionalPropertiesFalse(propSchema)
				}
			}
		}
	}

	// Handle array items: either a single schema or a tuple of schemas.
	if items, ok := schema["items"].(map[string]any); ok {
		addAdditionalPropertiesFalse(items)
	} else if itemList, ok := schema["items"].([]any); ok {
		for _, item := range itemList {
			if itemSchema, ok := item.(map[string]any); ok {
				addAdditionalPropertiesFalse(itemSchema)
			}
		}
	}

	// Subschema definitions may contain object schemas referenced via $ref.
	for _, key := range []string{"$defs", "definitions"} {
		if defs, ok := schema[key].(map[string]any); ok {
			for _, defValue := range defs {
				if defSchema, ok := defValue.(map[string]any); ok {
					addAdditionalPropertiesFalse(defSchema)
				}
			}
		}
	}

	// Combinator branches may also contain object schemas.
	for _, key := range []string{"allOf", "anyOf", "oneOf"} {
		if subs, ok := schema[key].([]any); ok {
			for _, sub := range subs {
				if subSchema, ok := sub.(map[string]any); ok {
					addAdditionalPropertiesFalse(subSchema)
				}
			}
		}
	}
}