1package openai
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "reflect"
10 "strings"
11
12 "charm.land/fantasy"
13 "charm.land/fantasy/object"
14 "charm.land/fantasy/schema"
15 xjson "github.com/charmbracelet/x/json"
16 "github.com/google/uuid"
17 "github.com/openai/openai-go/v2"
18 "github.com/openai/openai-go/v2/packages/param"
19 "github.com/openai/openai-go/v2/shared"
20)
21
// languageModel implements fantasy.LanguageModel on top of the OpenAI
// chat-completions API. The *Func fields are hooks that let
// OpenAI-compatible providers customize request preparation, finish-reason
// mapping, usage extraction, and prompt conversion; defaults are installed
// by newLanguageModel.
type languageModel struct {
	provider string // reported by Provider()
	modelID  string // reported by Model()
	client   openai.Client
	// objectMode selects the strategy used by GenerateObject/StreamObject.
	objectMode fantasy.ObjectMode
	// prepareCallFunc mutates params before each request and may add warnings.
	prepareCallFunc LanguageModelPrepareCallFunc
	// mapFinishReasonFunc converts the raw finish reason string to a fantasy one.
	mapFinishReasonFunc LanguageModelMapFinishReasonFunc
	// extraContentFunc may derive additional content parts from a choice.
	extraContentFunc LanguageModelExtraContentFunc
	// usageFunc extracts usage and provider metadata from a full response.
	usageFunc LanguageModelUsageFunc
	// streamUsageFunc updates usage/metadata from each streamed chunk.
	streamUsageFunc LanguageModelStreamUsageFunc
	// streamExtraFunc lets providers emit extra stream parts per chunk.
	streamExtraFunc LanguageModelStreamExtraFunc
	// streamProviderMetadataFunc finalizes metadata from the accumulated choice.
	streamProviderMetadataFunc LanguageModelStreamProviderMetadataFunc
	// toPromptFunc converts a fantasy prompt into OpenAI messages.
	toPromptFunc LanguageModelToPromptFunc
}
36
// LanguageModelOption is a function that configures a languageModel.
// Options are applied in order by newLanguageModel after defaults are set.
type LanguageModelOption = func(*languageModel)
39
40// WithLanguageModelPrepareCallFunc sets the prepare call function for the language model.
41func WithLanguageModelPrepareCallFunc(fn LanguageModelPrepareCallFunc) LanguageModelOption {
42 return func(l *languageModel) {
43 l.prepareCallFunc = fn
44 }
45}
46
47// WithLanguageModelMapFinishReasonFunc sets the map finish reason function for the language model.
48func WithLanguageModelMapFinishReasonFunc(fn LanguageModelMapFinishReasonFunc) LanguageModelOption {
49 return func(l *languageModel) {
50 l.mapFinishReasonFunc = fn
51 }
52}
53
54// WithLanguageModelExtraContentFunc sets the extra content function for the language model.
55func WithLanguageModelExtraContentFunc(fn LanguageModelExtraContentFunc) LanguageModelOption {
56 return func(l *languageModel) {
57 l.extraContentFunc = fn
58 }
59}
60
61// WithLanguageModelStreamExtraFunc sets the stream extra function for the language model.
62func WithLanguageModelStreamExtraFunc(fn LanguageModelStreamExtraFunc) LanguageModelOption {
63 return func(l *languageModel) {
64 l.streamExtraFunc = fn
65 }
66}
67
68// WithLanguageModelUsageFunc sets the usage function for the language model.
69func WithLanguageModelUsageFunc(fn LanguageModelUsageFunc) LanguageModelOption {
70 return func(l *languageModel) {
71 l.usageFunc = fn
72 }
73}
74
75// WithLanguageModelStreamUsageFunc sets the stream usage function for the language model.
76func WithLanguageModelStreamUsageFunc(fn LanguageModelStreamUsageFunc) LanguageModelOption {
77 return func(l *languageModel) {
78 l.streamUsageFunc = fn
79 }
80}
81
82// WithLanguageModelToPromptFunc sets the to prompt function for the language model.
83func WithLanguageModelToPromptFunc(fn LanguageModelToPromptFunc) LanguageModelOption {
84 return func(l *languageModel) {
85 l.toPromptFunc = fn
86 }
87}
88
89// WithLanguageModelObjectMode sets the object generation mode.
90func WithLanguageModelObjectMode(om fantasy.ObjectMode) LanguageModelOption {
91 return func(l *languageModel) {
92 // not supported
93 if om == fantasy.ObjectModeJSON {
94 om = fantasy.ObjectModeAuto
95 }
96 l.objectMode = om
97 }
98}
99
100func newLanguageModel(modelID string, provider string, client openai.Client, opts ...LanguageModelOption) languageModel {
101 model := languageModel{
102 modelID: modelID,
103 provider: provider,
104 client: client,
105 objectMode: fantasy.ObjectModeAuto,
106 prepareCallFunc: DefaultPrepareCallFunc,
107 mapFinishReasonFunc: DefaultMapFinishReasonFunc,
108 usageFunc: DefaultUsageFunc,
109 streamUsageFunc: DefaultStreamUsageFunc,
110 streamProviderMetadataFunc: DefaultStreamProviderMetadataFunc,
111 toPromptFunc: DefaultToPrompt,
112 }
113
114 for _, o := range opts {
115 o(&model)
116 }
117 return model
118}
119
// streamToolCall accumulates a single tool call across streaming deltas.
type streamToolCall struct {
	id        string
	name      string
	arguments string // concatenated JSON argument fragments
	// hasFinished is set once arguments form valid JSON and the final
	// tool-call part has been emitted; later deltas are ignored.
	hasFinished bool
}
126
// Model implements fantasy.LanguageModel. It returns the model identifier.
func (o languageModel) Model() string {
	return o.modelID
}
131
// Provider implements fantasy.LanguageModel. It returns the provider name.
func (o languageModel) Provider() string {
	return o.provider
}
136
// prepareParams converts a fantasy.Call into OpenAI chat-completion params,
// collecting warnings for settings the target model does not support. It
// applies reasoning-model and search-preview-model restrictions, runs the
// provider's prepareCallFunc hook, and attaches messages and tools.
func (o languageModel) prepareParams(call fantasy.Call) (*openai.ChatCompletionNewParams, []fantasy.CallWarning, error) {
	params := &openai.ChatCompletionNewParams{}
	messages, warnings := o.toPromptFunc(call.Prompt, o.provider, o.modelID)
	// top_k has no equivalent in the chat completions API.
	if call.TopK != nil {
		warnings = append(warnings, fantasy.CallWarning{
			Type:    fantasy.CallWarningTypeUnsupportedSetting,
			Setting: "top_k",
		})
	}

	if call.MaxOutputTokens != nil {
		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
	}
	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}
	if call.FrequencyPenalty != nil {
		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
	}
	if call.PresencePenalty != nil {
		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
	}

	if isReasoningModel(o.modelID) {
		// remove unsupported settings for reasoning models
		// see https://platform.openai.com/docs/guides/reasoning#limitations
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for reasoning models",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "TopP",
				Details: "TopP is not supported for reasoning models",
			})
		}
		if call.FrequencyPenalty != nil {
			params.FrequencyPenalty = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "FrequencyPenalty",
				Details: "FrequencyPenalty is not supported for reasoning models",
			})
		}
		if call.PresencePenalty != nil {
			params.PresencePenalty = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "PresencePenalty",
				Details: "PresencePenalty is not supported for reasoning models",
			})
		}

		// reasoning models use max_completion_tokens instead of max_tokens.
		// Only set it if a hook has not already done so.
		if call.MaxOutputTokens != nil {
			if !params.MaxCompletionTokens.Valid() {
				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
			}
			params.MaxTokens = param.Opt[int64]{}
		}
	}

	// Handle search preview models
	if isSearchPreviewModel(o.modelID) {
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for the search preview models and has been removed.",
			})
		}
	}

	// Provider-specific hook: may further adjust params and add warnings.
	optionsWarnings, err := o.prepareCallFunc(o, params, call)
	if err != nil {
		return nil, nil, err
	}

	if len(optionsWarnings) > 0 {
		warnings = append(warnings, optionsWarnings...)
	}

	params.Messages = messages
	params.Model = o.modelID

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}
	return params, warnings, nil
}
242
// Generate implements fantasy.LanguageModel. It performs a single
// (non-streaming) chat completion and converts the first choice into
// fantasy content parts, usage, and a finish reason.
func (o languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, toProviderErr(err)
	}

	if len(response.Choices) == 0 {
		return nil, &fantasy.Error{Title: "no response", Message: "no response generated"}
	}
	choice := response.Choices[0]
	// Capacity: one possible text part plus tool calls and annotations.
	content := make([]fantasy.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
	text := choice.Message.Content
	if text != "" {
		content = append(content, fantasy.TextContent{
			Text: text,
		})
	}
	// Provider-specific hook may contribute extra content parts.
	if o.extraContentFunc != nil {
		extraContent := o.extraContentFunc(choice)
		content = append(content, extraContent...)
	}
	for _, tc := range choice.Message.ToolCalls {
		toolCallID := tc.ID
		content = append(content, fantasy.ToolCallContent{
			ProviderExecuted: false,
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Surface URL citations as source content parts.
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, fantasy.SourceContent{
				SourceType: fantasy.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	usage, providerMetadata := o.usageFunc(*response)

	// Prefer the tool-calls finish reason whenever tool calls are present,
	// regardless of the reason the API reported.
	mappedFinishReason := o.mapFinishReasonFunc(choice.FinishReason)
	if len(choice.Message.ToolCalls) > 0 {
		mappedFinishReason = fantasy.FinishReasonToolCalls
	}
	return &fantasy.Response{
		Content:      content,
		Usage:        usage,
		FinishReason: mappedFinishReason,
		ProviderMetadata: fantasy.ProviderMetadata{
			Name: providerMetadata,
		},
		Warnings: warnings,
	}, nil
}
305
306// Stream implements fantasy.LanguageModel.
307func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
308 params, warnings, err := o.prepareParams(call)
309 if err != nil {
310 return nil, err
311 }
312
313 params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
314 IncludeUsage: openai.Bool(true),
315 }
316
317 stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
318 isActiveText := false
319 toolCalls := make(map[int64]streamToolCall)
320
321 providerMetadata := fantasy.ProviderMetadata{
322 Name: &ProviderMetadata{},
323 }
324 acc := openai.ChatCompletionAccumulator{}
325 extraContext := make(map[string]any)
326 var usage fantasy.Usage
327 var finishReason string
328 return func(yield func(fantasy.StreamPart) bool) {
329 if len(warnings) > 0 {
330 if !yield(fantasy.StreamPart{
331 Type: fantasy.StreamPartTypeWarnings,
332 Warnings: warnings,
333 }) {
334 return
335 }
336 }
337 for stream.Next() {
338 chunk := stream.Current()
339 acc.AddChunk(chunk)
340 usage, providerMetadata = o.streamUsageFunc(chunk, extraContext, providerMetadata)
341 if len(chunk.Choices) == 0 {
342 continue
343 }
344 for _, choice := range chunk.Choices {
345 if choice.FinishReason != "" {
346 finishReason = choice.FinishReason
347 }
348 switch {
349 case choice.Delta.Content != "":
350 if !isActiveText {
351 isActiveText = true
352 if !yield(fantasy.StreamPart{
353 Type: fantasy.StreamPartTypeTextStart,
354 ID: "0",
355 }) {
356 return
357 }
358 }
359 if !yield(fantasy.StreamPart{
360 Type: fantasy.StreamPartTypeTextDelta,
361 ID: "0",
362 Delta: choice.Delta.Content,
363 }) {
364 return
365 }
366 case len(choice.Delta.ToolCalls) > 0:
367 if isActiveText {
368 isActiveText = false
369 if !yield(fantasy.StreamPart{
370 Type: fantasy.StreamPartTypeTextEnd,
371 ID: "0",
372 }) {
373 return
374 }
375 }
376
377 for _, toolCallDelta := range choice.Delta.ToolCalls {
378 if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
379 if existingToolCall.hasFinished {
380 continue
381 }
382 if toolCallDelta.Function.Arguments != "" {
383 existingToolCall.arguments += toolCallDelta.Function.Arguments
384 }
385 if !yield(fantasy.StreamPart{
386 Type: fantasy.StreamPartTypeToolInputDelta,
387 ID: existingToolCall.id,
388 Delta: toolCallDelta.Function.Arguments,
389 }) {
390 return
391 }
392 toolCalls[toolCallDelta.Index] = existingToolCall
393 if xjson.IsValid(existingToolCall.arguments) {
394 if !yield(fantasy.StreamPart{
395 Type: fantasy.StreamPartTypeToolInputEnd,
396 ID: existingToolCall.id,
397 }) {
398 return
399 }
400
401 if !yield(fantasy.StreamPart{
402 Type: fantasy.StreamPartTypeToolCall,
403 ID: existingToolCall.id,
404 ToolCallName: existingToolCall.name,
405 ToolCallInput: existingToolCall.arguments,
406 }) {
407 return
408 }
409 existingToolCall.hasFinished = true
410 toolCalls[toolCallDelta.Index] = existingToolCall
411 }
412 } else {
413 var err error
414 if toolCallDelta.Type != "function" {
415 err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function' type."}
416 }
417 if toolCallDelta.ID == "" {
418 err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'id' to be a string."}
419 }
420 if toolCallDelta.Function.Name == "" {
421 err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function.name' to be a string."}
422 }
423 if err != nil {
424 yield(fantasy.StreamPart{
425 Type: fantasy.StreamPartTypeError,
426 Error: toProviderErr(stream.Err()),
427 })
428 return
429 }
430
431 if !yield(fantasy.StreamPart{
432 Type: fantasy.StreamPartTypeToolInputStart,
433 ID: toolCallDelta.ID,
434 ToolCallName: toolCallDelta.Function.Name,
435 }) {
436 return
437 }
438 toolCalls[toolCallDelta.Index] = streamToolCall{
439 id: toolCallDelta.ID,
440 name: toolCallDelta.Function.Name,
441 arguments: toolCallDelta.Function.Arguments,
442 }
443
444 exTc := toolCalls[toolCallDelta.Index]
445 if exTc.arguments != "" {
446 if !yield(fantasy.StreamPart{
447 Type: fantasy.StreamPartTypeToolInputDelta,
448 ID: exTc.id,
449 Delta: exTc.arguments,
450 }) {
451 return
452 }
453 if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
454 if !yield(fantasy.StreamPart{
455 Type: fantasy.StreamPartTypeToolInputEnd,
456 ID: toolCallDelta.ID,
457 }) {
458 return
459 }
460
461 if !yield(fantasy.StreamPart{
462 Type: fantasy.StreamPartTypeToolCall,
463 ID: exTc.id,
464 ToolCallName: exTc.name,
465 ToolCallInput: exTc.arguments,
466 }) {
467 return
468 }
469 exTc.hasFinished = true
470 toolCalls[toolCallDelta.Index] = exTc
471 }
472 }
473 continue
474 }
475 }
476 }
477
478 if o.streamExtraFunc != nil {
479 updatedContext, shouldContinue := o.streamExtraFunc(chunk, yield, extraContext)
480 if !shouldContinue {
481 return
482 }
483 extraContext = updatedContext
484 }
485 }
486
487 for _, choice := range chunk.Choices {
488 if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
489 for _, annotation := range annotations {
490 if annotation.Type == "url_citation" {
491 if !yield(fantasy.StreamPart{
492 Type: fantasy.StreamPartTypeSource,
493 ID: uuid.NewString(),
494 SourceType: fantasy.SourceTypeURL,
495 URL: annotation.URLCitation.URL,
496 Title: annotation.URLCitation.Title,
497 }) {
498 return
499 }
500 }
501 }
502 }
503 }
504 }
505 err := stream.Err()
506 if err == nil || errors.Is(err, io.EOF) {
507 if isActiveText {
508 isActiveText = false
509 if !yield(fantasy.StreamPart{
510 Type: fantasy.StreamPartTypeTextEnd,
511 ID: "0",
512 }) {
513 return
514 }
515 }
516
517 if len(acc.Choices) > 0 {
518 choice := acc.Choices[0]
519 providerMetadata = o.streamProviderMetadataFunc(choice, providerMetadata)
520
521 for _, annotation := range choice.Message.Annotations {
522 if annotation.Type == "url_citation" {
523 if !yield(fantasy.StreamPart{
524 Type: fantasy.StreamPartTypeSource,
525 ID: acc.ID,
526 SourceType: fantasy.SourceTypeURL,
527 URL: annotation.URLCitation.URL,
528 Title: annotation.URLCitation.Title,
529 }) {
530 return
531 }
532 }
533 }
534 }
535 mappedFinishReason := o.mapFinishReasonFunc(finishReason)
536 if len(acc.Choices) > 0 {
537 choice := acc.Choices[0]
538 if len(choice.Message.ToolCalls) > 0 {
539 mappedFinishReason = fantasy.FinishReasonToolCalls
540 }
541 }
542 yield(fantasy.StreamPart{
543 Type: fantasy.StreamPartTypeFinish,
544 Usage: usage,
545 FinishReason: mappedFinishReason,
546 ProviderMetadata: providerMetadata,
547 })
548 return
549 } else { //nolint: revive
550 yield(fantasy.StreamPart{
551 Type: fantasy.StreamPartTypeError,
552 Error: toProviderErr(err),
553 })
554 return
555 }
556 }, nil
557}
558
// isReasoningModel reports whether the model is an OpenAI reasoning model
// (o-series or the gpt-5 family). The gpt-5-chat models are standard chat
// models and are excluded even though they share the "gpt-5" prefix; the
// previous "|| HasPrefix gpt-5-chat" clause was dead code (already implied
// by the gpt-5 prefix) and inverted the intended exclusion.
func isReasoningModel(modelID string) bool {
	if strings.HasPrefix(modelID, "gpt-5-chat") {
		return false
	}
	return strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5")
}
562
// isSearchPreviewModel reports whether the model is one of the
// *-search-preview chat models, which reject the temperature setting.
func isSearchPreviewModel(modelID string) bool {
	return strings.Contains(modelID, "search-preview")
}
566
// supportsFlexProcessing reports whether the model accepts the "flex"
// service tier (o3, o4-mini, and gpt-5 families).
func supportsFlexProcessing(modelID string) bool {
	for _, prefix := range []string{"o3", "o4-mini", "gpt-5"} {
		if strings.HasPrefix(modelID, prefix) {
			return true
		}
	}
	return false
}
570
// supportsPriorityProcessing reports whether the model accepts the
// "priority" service tier. The former gpt-5-mini check was dead code —
// it is already covered by the gpt-5 prefix — so it has been removed;
// behavior is unchanged.
func supportsPriorityProcessing(modelID string) bool {
	return strings.HasPrefix(modelID, "gpt-4") || strings.HasPrefix(modelID, "gpt-5") ||
		strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini")
}
576
// toOpenAiTools converts fantasy tools and an optional tool choice into the
// OpenAI chat-completions representations. Non-function tools are skipped
// with an unsupported-tool warning.
func toOpenAiTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []fantasy.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == fantasy.ToolTypeFunction {
			ft, ok := tool.(fantasy.FunctionTool)
			if !ok {
				// Type claims function but concrete type differs; skip silently.
				continue
			}
			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
				OfFunction: &openai.ChatCompletionFunctionToolParam{
					Function: shared.FunctionDefinitionParam{
						Name:        ft.Name,
						Description: param.NewOpt(ft.Description),
						Parameters:  openai.FunctionParameters(ft.InputSchema),
						// Strict mode disabled: schemas are passed through as-is.
						Strict: param.NewOpt(false),
					},
					Type: "function",
				},
			})
			continue
		}

		warnings = append(warnings, fantasy.CallWarning{
			Type:    fantasy.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return openAiTools, openAiToolChoice, warnings
	}

	switch *toolChoice {
	case fantasy.ToolChoiceAuto:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("auto"),
		}
	case fantasy.ToolChoiceNone:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("none"),
		}
	default:
		// Any other value is treated as the name of a specific tool to force.
		// NOTE(review): if the fantasy package defines a "required" tool
		// choice constant, it would fall through to here and be sent as a
		// function name — verify against fantasy.ToolChoice's values.
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
				Type: "function",
				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
					Name: string(*toolChoice),
				},
			},
		}
	}
	return openAiTools, openAiToolChoice, warnings
}
629
630// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
631func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
632 var annotations []openai.ChatCompletionMessageAnnotation
633
634 // Parse the raw JSON to extract annotations
635 var deltaData map[string]any
636 if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
637 return annotations
638 }
639
640 // Check if annotations exist in the delta
641 if annotationsData, ok := deltaData["annotations"].([]any); ok {
642 for _, annotationData := range annotationsData {
643 if annotationMap, ok := annotationData.(map[string]any); ok {
644 if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
645 if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
646 annotation := openai.ChatCompletionMessageAnnotation{
647 Type: "url_citation",
648 URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
649 URL: urlCitationData["url"].(string),
650 Title: urlCitationData["title"].(string),
651 },
652 }
653 annotations = append(annotations, annotation)
654 }
655 }
656 }
657 }
658 }
659
660 return annotations
661}
662
663// GenerateObject implements fantasy.LanguageModel.
664func (o languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
665 switch o.objectMode {
666 case fantasy.ObjectModeText:
667 return object.GenerateWithText(ctx, o, call)
668 case fantasy.ObjectModeTool:
669 return object.GenerateWithTool(ctx, o, call)
670 default:
671 return o.generateObjectWithJSONMode(ctx, call)
672 }
673}
674
675// StreamObject implements fantasy.LanguageModel.
676func (o languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
677 switch o.objectMode {
678 case fantasy.ObjectModeTool:
679 return object.StreamWithTool(ctx, o, call)
680 case fantasy.ObjectModeText:
681 return object.StreamWithText(ctx, o, call)
682 default:
683 return o.streamObjectWithJSONMode(ctx, call)
684 }
685}
686
687func (o languageModel) generateObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
688 jsonSchemaMap := schema.ToMap(call.Schema)
689
690 addAdditionalPropertiesFalse(jsonSchemaMap)
691
692 schemaName := call.SchemaName
693 if schemaName == "" {
694 schemaName = "response"
695 }
696
697 fantasyCall := fantasy.Call{
698 Prompt: call.Prompt,
699 MaxOutputTokens: call.MaxOutputTokens,
700 Temperature: call.Temperature,
701 TopP: call.TopP,
702 PresencePenalty: call.PresencePenalty,
703 FrequencyPenalty: call.FrequencyPenalty,
704 ProviderOptions: call.ProviderOptions,
705 }
706
707 params, warnings, err := o.prepareParams(fantasyCall)
708 if err != nil {
709 return nil, err
710 }
711
712 params.ResponseFormat = openai.ChatCompletionNewParamsResponseFormatUnion{
713 OfJSONSchema: &shared.ResponseFormatJSONSchemaParam{
714 JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
715 Name: schemaName,
716 Description: param.NewOpt(call.SchemaDescription),
717 Schema: jsonSchemaMap,
718 Strict: param.NewOpt(true),
719 },
720 },
721 }
722
723 response, err := o.client.Chat.Completions.New(ctx, *params)
724 if err != nil {
725 return nil, toProviderErr(err)
726 }
727
728 if len(response.Choices) == 0 {
729 usage, _ := o.usageFunc(*response)
730 return nil, &fantasy.NoObjectGeneratedError{
731 RawText: "",
732 ParseError: fmt.Errorf("no choices in response"),
733 Usage: usage,
734 FinishReason: fantasy.FinishReasonUnknown,
735 }
736 }
737
738 choice := response.Choices[0]
739 jsonText := choice.Message.Content
740
741 var obj any
742 if call.RepairText != nil {
743 obj, err = schema.ParseAndValidateWithRepair(ctx, jsonText, call.Schema, call.RepairText)
744 } else {
745 obj, err = schema.ParseAndValidate(jsonText, call.Schema)
746 }
747
748 usage, _ := o.usageFunc(*response)
749 finishReason := o.mapFinishReasonFunc(choice.FinishReason)
750
751 if err != nil {
752 if nogErr, ok := err.(*fantasy.NoObjectGeneratedError); ok {
753 nogErr.Usage = usage
754 nogErr.FinishReason = finishReason
755 }
756 return nil, err
757 }
758
759 return &fantasy.ObjectResponse{
760 Object: obj,
761 RawText: jsonText,
762 Usage: usage,
763 FinishReason: finishReason,
764 Warnings: warnings,
765 }, nil
766}
767
// streamObjectWithJSONMode streams a structured object using OpenAI's
// JSON-schema response format. As content deltas arrive it repeatedly
// attempts a partial-JSON parse of the accumulated text and yields each new
// distinct object that validates against the schema; the final part carries
// usage and finish reason, or an error if no valid object was produced.
func (o languageModel) streamObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
	jsonSchemaMap := schema.ToMap(call.Schema)

	// OpenAI strict mode requires additionalProperties:false on object schemas.
	addAdditionalPropertiesFalse(jsonSchemaMap)

	schemaName := call.SchemaName
	if schemaName == "" {
		schemaName = "response"
	}

	fantasyCall := fantasy.Call{
		Prompt:           call.Prompt,
		MaxOutputTokens:  call.MaxOutputTokens,
		Temperature:      call.Temperature,
		TopP:             call.TopP,
		PresencePenalty:  call.PresencePenalty,
		FrequencyPenalty: call.FrequencyPenalty,
		ProviderOptions:  call.ProviderOptions,
	}

	params, warnings, err := o.prepareParams(fantasyCall)
	if err != nil {
		return nil, err
	}

	params.ResponseFormat = openai.ChatCompletionNewParamsResponseFormatUnion{
		OfJSONSchema: &shared.ResponseFormatJSONSchemaParam{
			JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
				Name:        schemaName,
				Description: param.NewOpt(call.SchemaDescription),
				Schema:      jsonSchemaMap,
				Strict:      param.NewOpt(true),
			},
		},
	}

	// Request token usage in the final chunk.
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)

	return func(yield func(fantasy.ObjectStreamPart) bool) {
		// NOTE(review): warnings are emitted with the Object part type;
		// confirm there is no dedicated warnings part type for object streams.
		if len(warnings) > 0 {
			if !yield(fantasy.ObjectStreamPart{
				Type:     fantasy.ObjectStreamPartTypeObject,
				Warnings: warnings,
			}) {
				return
			}
		}

		var accumulated string
		var lastParsedObject any
		var usage fantasy.Usage
		var finishReason fantasy.FinishReason
		var providerMetadata fantasy.ProviderMetadata
		var streamErr error

		for stream.Next() {
			chunk := stream.Current()

			// Update usage
			usage, providerMetadata = o.streamUsageFunc(chunk, make(map[string]any), providerMetadata)

			if len(chunk.Choices) == 0 {
				continue
			}

			choice := chunk.Choices[0]
			if choice.FinishReason != "" {
				finishReason = o.mapFinishReasonFunc(choice.FinishReason)
			}

			if choice.Delta.Content != "" {
				accumulated += choice.Delta.Content

				obj, state, parseErr := schema.ParsePartialJSON(accumulated)

				// Yield only when the parse is usable, the object validates,
				// and it differs from the last object already yielded.
				if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
					if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
						if !reflect.DeepEqual(obj, lastParsedObject) {
							if !yield(fantasy.ObjectStreamPart{
								Type:   fantasy.ObjectStreamPartTypeObject,
								Object: obj,
							}) {
								return
							}
							lastParsedObject = obj
						}
					}
				}

				// Optional caller-supplied repair pass when parsing failed.
				if state == schema.ParseStateFailed && call.RepairText != nil {
					repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
					if repairErr == nil {
						obj2, state2, _ := schema.ParsePartialJSON(repairedText)
						if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
							schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
							if !reflect.DeepEqual(obj2, lastParsedObject) {
								if !yield(fantasy.ObjectStreamPart{
									Type:   fantasy.ObjectStreamPartTypeObject,
									Object: obj2,
								}) {
									return
								}
								lastParsedObject = obj2
							}
						}
					}
				}
			}
		}

		err := stream.Err()
		if err != nil && !errors.Is(err, io.EOF) {
			streamErr = toProviderErr(err)
			yield(fantasy.ObjectStreamPart{
				Type:  fantasy.ObjectStreamPartTypeError,
				Error: streamErr,
			})
			return
		}

		if lastParsedObject != nil {
			yield(fantasy.ObjectStreamPart{
				Type:             fantasy.ObjectStreamPartTypeFinish,
				Usage:            usage,
				FinishReason:     finishReason,
				ProviderMetadata: providerMetadata,
			})
		} else {
			// The stream ended without ever producing a valid object.
			yield(fantasy.ObjectStreamPart{
				Type: fantasy.ObjectStreamPartTypeError,
				Error: &fantasy.NoObjectGeneratedError{
					RawText:      accumulated,
					ParseError:   fmt.Errorf("no valid object generated in stream"),
					Usage:        usage,
					FinishReason: finishReason,
				},
			})
		}
	}, nil
}
912
// addAdditionalPropertiesFalse recursively adds "additionalProperties": false
// to all object schemas, as required by OpenAI's strict mode for structured
// outputs. Compared to the previous version it also descends into
// $defs/definitions and anyOf/oneOf/allOf branches (strict mode requires the
// flag on every object in the schema, not just those reachable through
// properties/items), and the parameter no longer shadows the imported
// schema package.
func addAdditionalPropertiesFalse(schemaMap map[string]any) {
	if schemaMap["type"] == "object" {
		if _, hasAdditional := schemaMap["additionalProperties"]; !hasAdditional {
			schemaMap["additionalProperties"] = false
		}
	}

	// Recursively process nested property schemas.
	if properties, ok := schemaMap["properties"].(map[string]any); ok {
		for _, propValue := range properties {
			if propSchema, ok := propValue.(map[string]any); ok {
				addAdditionalPropertiesFalse(propSchema)
			}
		}
	}

	// Handle array items.
	if items, ok := schemaMap["items"].(map[string]any); ok {
		addAdditionalPropertiesFalse(items)
	}

	// Handle reusable definitions.
	for _, key := range []string{"$defs", "definitions"} {
		if defs, ok := schemaMap[key].(map[string]any); ok {
			for _, defValue := range defs {
				if defSchema, ok := defValue.(map[string]any); ok {
					addAdditionalPropertiesFalse(defSchema)
				}
			}
		}
	}

	// Handle union branches.
	for _, key := range []string{"anyOf", "oneOf", "allOf"} {
		if variants, ok := schemaMap[key].([]any); ok {
			for _, variant := range variants {
				if variantSchema, ok := variant.(map[string]any); ok {
					addAdditionalPropertiesFalse(variantSchema)
				}
			}
		}
	}
}