1package openai
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "reflect"
10 "strings"
11
12 "charm.land/fantasy"
13 "charm.land/fantasy/object"
14 "charm.land/fantasy/schema"
15 xjson "github.com/charmbracelet/x/json"
16 "github.com/google/uuid"
17 "github.com/openai/openai-go/v2"
18 "github.com/openai/openai-go/v2/packages/param"
19 "github.com/openai/openai-go/v2/shared"
20)
21
// languageModel is an OpenAI chat-completions implementation of
// fantasy.LanguageModel. The function-valued fields are customization
// hooks: newLanguageModel fills most of them with Default* implementations,
// and LanguageModelOption values may override them.
type languageModel struct {
	provider            string // provider name returned by Provider()
	modelID             string // model identifier returned by Model()
	client              openai.Client
	objectMode          fantasy.ObjectMode // strategy used by GenerateObject/StreamObject
	prepareCallFunc     LanguageModelPrepareCallFunc
	mapFinishReasonFunc LanguageModelMapFinishReasonFunc
	// extraContentFunc and streamExtraFunc have no defaults; they are nil
	// unless set via options.
	extraContentFunc           LanguageModelExtraContentFunc
	usageFunc                  LanguageModelUsageFunc
	streamUsageFunc            LanguageModelStreamUsageFunc
	streamExtraFunc            LanguageModelStreamExtraFunc
	streamProviderMetadataFunc LanguageModelStreamProviderMetadataFunc
	toPromptFunc               LanguageModelToPromptFunc
}
36
// LanguageModelOption is a function that configures a languageModel.
// Options are applied in order by newLanguageModel after the defaults
// have been set, so later options win.
type LanguageModelOption = func(*languageModel)
39
// WithLanguageModelPrepareCallFunc sets the prepare call function for the
// language model, replacing DefaultPrepareCallFunc. The hook runs near the
// end of prepareParams and may mutate the request params and add warnings.
func WithLanguageModelPrepareCallFunc(fn LanguageModelPrepareCallFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.prepareCallFunc = fn
	}
}
46
// WithLanguageModelMapFinishReasonFunc sets the map finish reason function
// for the language model, replacing DefaultMapFinishReasonFunc. Generate and
// Stream use it to convert OpenAI finish reasons to fantasy finish reasons.
func WithLanguageModelMapFinishReasonFunc(fn LanguageModelMapFinishReasonFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.mapFinishReasonFunc = fn
	}
}
53
// WithLanguageModelExtraContentFunc sets the extra content function for the
// language model. There is no default; when set, Generate appends the
// returned content parts after the text content of the first choice.
func WithLanguageModelExtraContentFunc(fn LanguageModelExtraContentFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.extraContentFunc = fn
	}
}
60
// WithLanguageModelStreamExtraFunc sets the stream extra function for the
// language model. There is no default; when set, Stream invokes it for every
// chunk with the yield function and a shared context map, and stops the
// stream if the hook reports it should not continue.
func WithLanguageModelStreamExtraFunc(fn LanguageModelStreamExtraFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.streamExtraFunc = fn
	}
}
67
// WithLanguageModelUsageFunc sets the usage function for the language model,
// replacing DefaultUsageFunc. It extracts token usage and provider metadata
// from a non-streaming completion response.
func WithLanguageModelUsageFunc(fn LanguageModelUsageFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.usageFunc = fn
	}
}
74
// WithLanguageModelStreamUsageFunc sets the stream usage function for the
// language model, replacing DefaultStreamUsageFunc. It is called once per
// streamed chunk to update usage and provider metadata.
func WithLanguageModelStreamUsageFunc(fn LanguageModelStreamUsageFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.streamUsageFunc = fn
	}
}
81
// WithLanguageModelToPromptFunc sets the to prompt function for the language
// model, replacing DefaultToPrompt. It converts a fantasy prompt into OpenAI
// chat messages plus any conversion warnings.
func WithLanguageModelToPromptFunc(fn LanguageModelToPromptFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.toPromptFunc = fn
	}
}
88
// WithLanguageModelObjectMode sets the object generation mode used by
// GenerateObject and StreamObject. ObjectModeJSON is not supported by this
// provider and is silently coerced to ObjectModeAuto.
func WithLanguageModelObjectMode(om fantasy.ObjectMode) LanguageModelOption {
	return func(l *languageModel) {
		// not supported: JSON mode falls back to the auto strategy.
		if om == fantasy.ObjectModeJSON {
			om = fantasy.ObjectModeAuto
		}
		l.objectMode = om
	}
}
99
100func newLanguageModel(modelID string, provider string, client openai.Client, opts ...LanguageModelOption) languageModel {
101 model := languageModel{
102 modelID: modelID,
103 provider: provider,
104 client: client,
105 objectMode: fantasy.ObjectModeAuto,
106 prepareCallFunc: DefaultPrepareCallFunc,
107 mapFinishReasonFunc: DefaultMapFinishReasonFunc,
108 usageFunc: DefaultUsageFunc,
109 streamUsageFunc: DefaultStreamUsageFunc,
110 streamProviderMetadataFunc: DefaultStreamProviderMetadataFunc,
111 toPromptFunc: DefaultToPrompt,
112 }
113
114 for _, o := range opts {
115 o(&model)
116 }
117 return model
118}
119
// streamToolCall accumulates a single tool call whose pieces arrive across
// multiple streamed chunks.
type streamToolCall struct {
	id          string
	name        string
	arguments   string // concatenated JSON argument fragments
	hasFinished bool   // true once the completed tool call has been yielded
}
126
// Model implements fantasy.LanguageModel. It returns the model identifier
// this instance was constructed with.
func (o languageModel) Model() string {
	return o.modelID
}
131
// Provider implements fantasy.LanguageModel. It returns the provider name
// this instance was constructed with.
func (o languageModel) Provider() string {
	return o.provider
}
136
// prepareParams translates a fantasy.Call into OpenAI chat-completion request
// parameters. It converts the prompt, copies sampling settings, strips
// settings that the target model family does not support (collecting a
// warning for each), runs the prepareCallFunc hook, and finally attaches
// messages, model and tools. It returns the params, accumulated warnings,
// and any error from the hook.
func (o languageModel) prepareParams(call fantasy.Call) (*openai.ChatCompletionNewParams, []fantasy.CallWarning, error) {
	params := &openai.ChatCompletionNewParams{}
	messages, warnings := o.toPromptFunc(call.Prompt, o.provider, o.modelID)
	// top_k has no equivalent in the chat completions API.
	if call.TopK != nil {
		warnings = append(warnings, fantasy.CallWarning{
			Type:    fantasy.CallWarningTypeUnsupportedSetting,
			Setting: "top_k",
		})
	}

	if call.MaxOutputTokens != nil {
		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
	}
	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}
	if call.FrequencyPenalty != nil {
		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
	}
	if call.PresencePenalty != nil {
		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
	}

	if isReasoningModel(o.modelID) {
		// remove unsupported settings for reasoning models
		// see https://platform.openai.com/docs/guides/reasoning#limitations
		// NOTE(review): warning Setting values mix snake_case ("temperature")
		// and CamelCase ("TopP"/"FrequencyPenalty") — confirm which casing
		// consumers expect.
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{} // reset to unset
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for reasoning models",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "TopP",
				Details: "TopP is not supported for reasoning models",
			})
		}
		if call.FrequencyPenalty != nil {
			params.FrequencyPenalty = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "FrequencyPenalty",
				Details: "FrequencyPenalty is not supported for reasoning models",
			})
		}
		if call.PresencePenalty != nil {
			params.PresencePenalty = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "PresencePenalty",
				Details: "PresencePenalty is not supported for reasoning models",
			})
		}

		// reasoning models use max_completion_tokens instead of max_tokens;
		// the Valid() guard keeps MaxCompletionTokens if it was already set.
		if call.MaxOutputTokens != nil {
			if !params.MaxCompletionTokens.Valid() {
				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
			}
			params.MaxTokens = param.Opt[int64]{}
		}
	}

	// Handle search preview models: they reject temperature.
	if isSearchPreviewModel(o.modelID) {
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for the search preview models and has been removed.",
			})
		}
	}

	// Provider-specific hook: may mutate params and add warnings.
	optionsWarnings, err := o.prepareCallFunc(o, params, call)
	if err != nil {
		return nil, nil, err
	}

	if len(optionsWarnings) > 0 {
		warnings = append(warnings, optionsWarnings...)
	}

	params.Messages = messages
	params.Model = o.modelID

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}
	return params, warnings, nil
}
242
// Generate implements fantasy.LanguageModel. It performs a single
// (non-streaming) chat completion and converts the first choice into
// fantasy content parts: text first, then any extra content from the
// extraContentFunc hook, then tool calls, then URL-citation sources.
func (o languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, toProviderErr(err)
	}

	if len(response.Choices) == 0 {
		return nil, &fantasy.Error{Title: "no response", Message: "no response generated"}
	}
	// Only the first choice is converted; additional choices are ignored.
	choice := response.Choices[0]
	content := make([]fantasy.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
	text := choice.Message.Content
	if text != "" {
		content = append(content, fantasy.TextContent{
			Text: text,
		})
	}
	if o.extraContentFunc != nil {
		extraContent := o.extraContentFunc(choice)
		content = append(content, extraContent...)
	}
	for _, tc := range choice.Message.ToolCalls {
		toolCallID := tc.ID
		content = append(content, fantasy.ToolCallContent{
			ProviderExecuted: false,
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Surface url_citation annotations as source content with fresh IDs.
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, fantasy.SourceContent{
				SourceType: fantasy.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	usage, providerMetadata := o.usageFunc(*response)

	// Any tool call forces the finish reason to tool-calls, regardless of
	// what the provider reported.
	mappedFinishReason := o.mapFinishReasonFunc(choice.FinishReason)
	if len(choice.Message.ToolCalls) > 0 {
		mappedFinishReason = fantasy.FinishReasonToolCalls
	}
	return &fantasy.Response{
		Content:      content,
		Usage:        usage,
		FinishReason: mappedFinishReason,
		ProviderMetadata: fantasy.ProviderMetadata{
			Name: providerMetadata,
		},
		Warnings: warnings,
	}, nil
}
305
306// Stream implements fantasy.LanguageModel.
307func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
308 params, warnings, err := o.prepareParams(call)
309 if err != nil {
310 return nil, err
311 }
312
313 params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
314 IncludeUsage: openai.Bool(true),
315 }
316
317 stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
318 isActiveText := false
319 toolCalls := make(map[int64]streamToolCall)
320
321 providerMetadata := fantasy.ProviderMetadata{
322 Name: &ProviderMetadata{},
323 }
324 acc := openai.ChatCompletionAccumulator{}
325 extraContext := make(map[string]any)
326 var usage fantasy.Usage
327 var finishReason string
328 return func(yield func(fantasy.StreamPart) bool) {
329 if len(warnings) > 0 {
330 if !yield(fantasy.StreamPart{
331 Type: fantasy.StreamPartTypeWarnings,
332 Warnings: warnings,
333 }) {
334 return
335 }
336 }
337 for stream.Next() {
338 chunk := stream.Current()
339 acc.AddChunk(chunk)
340 usage, providerMetadata = o.streamUsageFunc(chunk, extraContext, providerMetadata)
341 if len(chunk.Choices) == 0 {
342 continue
343 }
344 for _, choice := range chunk.Choices {
345 if choice.FinishReason != "" {
346 finishReason = choice.FinishReason
347 }
348 switch {
349 case choice.Delta.Content != "":
350 if !isActiveText {
351 isActiveText = true
352 if !yield(fantasy.StreamPart{
353 Type: fantasy.StreamPartTypeTextStart,
354 ID: "0",
355 }) {
356 return
357 }
358 }
359 if !yield(fantasy.StreamPart{
360 Type: fantasy.StreamPartTypeTextDelta,
361 ID: "0",
362 Delta: choice.Delta.Content,
363 }) {
364 return
365 }
366 case len(choice.Delta.ToolCalls) > 0:
367 if isActiveText {
368 isActiveText = false
369 if !yield(fantasy.StreamPart{
370 Type: fantasy.StreamPartTypeTextEnd,
371 ID: "0",
372 }) {
373 return
374 }
375 }
376
377 for _, toolCallDelta := range choice.Delta.ToolCalls {
378 if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
379 if existingToolCall.hasFinished {
380 continue
381 }
382 if toolCallDelta.Function.Arguments != "" {
383 existingToolCall.arguments += toolCallDelta.Function.Arguments
384 }
385 if !yield(fantasy.StreamPart{
386 Type: fantasy.StreamPartTypeToolInputDelta,
387 ID: existingToolCall.id,
388 Delta: toolCallDelta.Function.Arguments,
389 }) {
390 return
391 }
392 toolCalls[toolCallDelta.Index] = existingToolCall
393 if xjson.IsValid(existingToolCall.arguments) {
394 if !yield(fantasy.StreamPart{
395 Type: fantasy.StreamPartTypeToolInputEnd,
396 ID: existingToolCall.id,
397 }) {
398 return
399 }
400
401 if !yield(fantasy.StreamPart{
402 Type: fantasy.StreamPartTypeToolCall,
403 ID: existingToolCall.id,
404 ToolCallName: existingToolCall.name,
405 ToolCallInput: existingToolCall.arguments,
406 }) {
407 return
408 }
409 existingToolCall.hasFinished = true
410 toolCalls[toolCallDelta.Index] = existingToolCall
411 }
412 } else {
413 var err error
414 if toolCallDelta.Type != "function" {
415 err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function' type."}
416 }
417 if toolCallDelta.ID == "" {
418 err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'id' to be a string."}
419 }
420 if toolCallDelta.Function.Name == "" {
421 err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function.name' to be a string."}
422 }
423 if err != nil {
424 yield(fantasy.StreamPart{
425 Type: fantasy.StreamPartTypeError,
426 Error: toProviderErr(stream.Err()),
427 })
428 return
429 }
430
431 if !yield(fantasy.StreamPart{
432 Type: fantasy.StreamPartTypeToolInputStart,
433 ID: toolCallDelta.ID,
434 ToolCallName: toolCallDelta.Function.Name,
435 }) {
436 return
437 }
438 toolCalls[toolCallDelta.Index] = streamToolCall{
439 id: toolCallDelta.ID,
440 name: toolCallDelta.Function.Name,
441 arguments: toolCallDelta.Function.Arguments,
442 }
443
444 exTc := toolCalls[toolCallDelta.Index]
445 if exTc.arguments != "" {
446 if !yield(fantasy.StreamPart{
447 Type: fantasy.StreamPartTypeToolInputDelta,
448 ID: exTc.id,
449 Delta: exTc.arguments,
450 }) {
451 return
452 }
453 if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
454 if !yield(fantasy.StreamPart{
455 Type: fantasy.StreamPartTypeToolInputEnd,
456 ID: toolCallDelta.ID,
457 }) {
458 return
459 }
460
461 if !yield(fantasy.StreamPart{
462 Type: fantasy.StreamPartTypeToolCall,
463 ID: exTc.id,
464 ToolCallName: exTc.name,
465 ToolCallInput: exTc.arguments,
466 }) {
467 return
468 }
469 exTc.hasFinished = true
470 toolCalls[toolCallDelta.Index] = exTc
471 }
472 }
473 continue
474 }
475 }
476 }
477
478 if o.streamExtraFunc != nil {
479 updatedContext, shouldContinue := o.streamExtraFunc(chunk, yield, extraContext)
480 if !shouldContinue {
481 return
482 }
483 extraContext = updatedContext
484 }
485 }
486
487 for _, choice := range chunk.Choices {
488 if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
489 for _, annotation := range annotations {
490 if annotation.Type == "url_citation" {
491 if !yield(fantasy.StreamPart{
492 Type: fantasy.StreamPartTypeSource,
493 ID: uuid.NewString(),
494 SourceType: fantasy.SourceTypeURL,
495 URL: annotation.URLCitation.URL,
496 Title: annotation.URLCitation.Title,
497 }) {
498 return
499 }
500 }
501 }
502 }
503 }
504 }
505 err := stream.Err()
506 if err == nil || errors.Is(err, io.EOF) {
507 if isActiveText {
508 isActiveText = false
509 if !yield(fantasy.StreamPart{
510 Type: fantasy.StreamPartTypeTextEnd,
511 ID: "0",
512 }) {
513 return
514 }
515 }
516
517 if len(acc.Choices) > 0 {
518 choice := acc.Choices[0]
519 providerMetadata = o.streamProviderMetadataFunc(choice, providerMetadata)
520
521 for _, annotation := range choice.Message.Annotations {
522 if annotation.Type == "url_citation" {
523 if !yield(fantasy.StreamPart{
524 Type: fantasy.StreamPartTypeSource,
525 ID: acc.ID,
526 SourceType: fantasy.SourceTypeURL,
527 URL: annotation.URLCitation.URL,
528 Title: annotation.URLCitation.Title,
529 }) {
530 return
531 }
532 }
533 }
534 }
535 mappedFinishReason := o.mapFinishReasonFunc(finishReason)
536 if len(acc.Choices) > 0 {
537 choice := acc.Choices[0]
538 if len(choice.Message.ToolCalls) > 0 {
539 mappedFinishReason = fantasy.FinishReasonToolCalls
540 }
541 }
542 yield(fantasy.StreamPart{
543 Type: fantasy.StreamPartTypeFinish,
544 Usage: usage,
545 FinishReason: mappedFinishReason,
546 ProviderMetadata: providerMetadata,
547 })
548 return
549 } else { //nolint: revive
550 yield(fantasy.StreamPart{
551 Type: fantasy.StreamPartTypeError,
552 Error: toProviderErr(err),
553 })
554 return
555 }
556 }, nil
557}
558
// isReasoningModel reports whether the model belongs to a reasoning family
// (o1/o3/o4, gpt-oss, gpt-5) whose sampling settings must be stripped in
// prepareParams.
//
// Fix: the original ended with `Contains("gpt-5") || Contains("gpt-5-chat")`,
// where the second clause was dead code (any string containing "gpt-5-chat"
// also contains "gpt-5"). The only sensible reason to single out
// "gpt-5-chat" is to EXCLUDE it: gpt-5-chat(-latest) is the non-reasoning
// chat variant and does support temperature/top_p, so it must not be
// treated as a reasoning model.
func isReasoningModel(modelID string) bool {
	if strings.Contains(modelID, "gpt-5-chat") {
		return false
	}
	return strings.HasPrefix(modelID, "o1") || strings.Contains(modelID, "-o1") ||
		strings.HasPrefix(modelID, "o3") || strings.Contains(modelID, "-o3") ||
		strings.HasPrefix(modelID, "o4") || strings.Contains(modelID, "-o4") ||
		strings.HasPrefix(modelID, "oss") || strings.Contains(modelID, "-oss") ||
		strings.Contains(modelID, "gpt-5")
}
566
// isSearchPreviewModel reports whether the model is a "search-preview"
// variant; prepareParams strips temperature for these models.
func isSearchPreviewModel(modelID string) bool {
	return strings.Contains(modelID, "search-preview")
}
570
// supportsFlexProcessing reports whether the model accepts the "flex"
// service tier (o3 family, o4-mini, and gpt-5 models).
func supportsFlexProcessing(modelID string) bool {
	if strings.HasPrefix(modelID, "o3") {
		return true
	}
	for _, marker := range []string{"-o3", "o4-mini", "gpt-5"} {
		if strings.Contains(modelID, marker) {
			return true
		}
	}
	return false
}
575
// supportsPriorityProcessing reports whether the model accepts the
// "priority" service tier (gpt-4/gpt-5 families, o3 family, o4-mini).
//
// Fix: removed the redundant Contains(modelID, "gpt-5-mini") clause — any
// ID containing "gpt-5-mini" already matches Contains(modelID, "gpt-5"),
// so behavior is unchanged.
func supportsPriorityProcessing(modelID string) bool {
	return strings.Contains(modelID, "gpt-4") || strings.Contains(modelID, "gpt-5") ||
		strings.HasPrefix(modelID, "o3") || strings.Contains(modelID, "-o3") ||
		strings.Contains(modelID, "o4-mini")
}
581
582func toOpenAiTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []fantasy.CallWarning) {
583 for _, tool := range tools {
584 if tool.GetType() == fantasy.ToolTypeFunction {
585 ft, ok := tool.(fantasy.FunctionTool)
586 if !ok {
587 continue
588 }
589 openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
590 OfFunction: &openai.ChatCompletionFunctionToolParam{
591 Function: shared.FunctionDefinitionParam{
592 Name: ft.Name,
593 Description: param.NewOpt(ft.Description),
594 Parameters: openai.FunctionParameters(ft.InputSchema),
595 Strict: param.NewOpt(false),
596 },
597 Type: "function",
598 },
599 })
600 continue
601 }
602
603 warnings = append(warnings, fantasy.CallWarning{
604 Type: fantasy.CallWarningTypeUnsupportedTool,
605 Tool: tool,
606 Message: "tool is not supported",
607 })
608 }
609 if toolChoice == nil {
610 return openAiTools, openAiToolChoice, warnings
611 }
612
613 switch *toolChoice {
614 case fantasy.ToolChoiceAuto:
615 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
616 OfAuto: param.NewOpt("auto"),
617 }
618 case fantasy.ToolChoiceNone:
619 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
620 OfAuto: param.NewOpt("none"),
621 }
622 default:
623 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
624 OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
625 Type: "function",
626 Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
627 Name: string(*toolChoice),
628 },
629 },
630 }
631 }
632 return openAiTools, openAiToolChoice, warnings
633}
634
635// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
636func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
637 var annotations []openai.ChatCompletionMessageAnnotation
638
639 // Parse the raw JSON to extract annotations
640 var deltaData map[string]any
641 if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
642 return annotations
643 }
644
645 // Check if annotations exist in the delta
646 if annotationsData, ok := deltaData["annotations"].([]any); ok {
647 for _, annotationData := range annotationsData {
648 if annotationMap, ok := annotationData.(map[string]any); ok {
649 if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
650 if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
651 url, urlOk := urlCitationData["url"].(string)
652 title, titleOk := urlCitationData["title"].(string)
653 if urlOk && titleOk {
654 annotation := openai.ChatCompletionMessageAnnotation{
655 Type: "url_citation",
656 URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
657 URL: url,
658 Title: title,
659 },
660 }
661 annotations = append(annotations, annotation)
662 }
663 }
664 }
665 }
666 }
667 }
668
669 return annotations
670}
671
672// GenerateObject implements fantasy.LanguageModel.
673func (o languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
674 switch o.objectMode {
675 case fantasy.ObjectModeText:
676 return object.GenerateWithText(ctx, o, call)
677 case fantasy.ObjectModeTool:
678 return object.GenerateWithTool(ctx, o, call)
679 default:
680 return o.generateObjectWithJSONMode(ctx, call)
681 }
682}
683
684// StreamObject implements fantasy.LanguageModel.
685func (o languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
686 switch o.objectMode {
687 case fantasy.ObjectModeTool:
688 return object.StreamWithTool(ctx, o, call)
689 case fantasy.ObjectModeText:
690 return object.StreamWithText(ctx, o, call)
691 default:
692 return o.streamObjectWithJSONMode(ctx, call)
693 }
694}
695
696func (o languageModel) generateObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
697 jsonSchemaMap := schema.ToMap(call.Schema)
698
699 addAdditionalPropertiesFalse(jsonSchemaMap)
700
701 schemaName := call.SchemaName
702 if schemaName == "" {
703 schemaName = "response"
704 }
705
706 fantasyCall := fantasy.Call{
707 Prompt: call.Prompt,
708 MaxOutputTokens: call.MaxOutputTokens,
709 Temperature: call.Temperature,
710 TopP: call.TopP,
711 PresencePenalty: call.PresencePenalty,
712 FrequencyPenalty: call.FrequencyPenalty,
713 ProviderOptions: call.ProviderOptions,
714 }
715
716 params, warnings, err := o.prepareParams(fantasyCall)
717 if err != nil {
718 return nil, err
719 }
720
721 params.ResponseFormat = openai.ChatCompletionNewParamsResponseFormatUnion{
722 OfJSONSchema: &shared.ResponseFormatJSONSchemaParam{
723 JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
724 Name: schemaName,
725 Description: param.NewOpt(call.SchemaDescription),
726 Schema: jsonSchemaMap,
727 Strict: param.NewOpt(true),
728 },
729 },
730 }
731
732 response, err := o.client.Chat.Completions.New(ctx, *params)
733 if err != nil {
734 return nil, toProviderErr(err)
735 }
736
737 if len(response.Choices) == 0 {
738 usage, _ := o.usageFunc(*response)
739 return nil, &fantasy.NoObjectGeneratedError{
740 RawText: "",
741 ParseError: fmt.Errorf("no choices in response"),
742 Usage: usage,
743 FinishReason: fantasy.FinishReasonUnknown,
744 }
745 }
746
747 choice := response.Choices[0]
748 jsonText := choice.Message.Content
749
750 var obj any
751 if call.RepairText != nil {
752 obj, err = schema.ParseAndValidateWithRepair(ctx, jsonText, call.Schema, call.RepairText)
753 } else {
754 obj, err = schema.ParseAndValidate(jsonText, call.Schema)
755 }
756
757 usage, _ := o.usageFunc(*response)
758 finishReason := o.mapFinishReasonFunc(choice.FinishReason)
759
760 if err != nil {
761 if nogErr, ok := err.(*fantasy.NoObjectGeneratedError); ok {
762 nogErr.Usage = usage
763 nogErr.FinishReason = finishReason
764 }
765 return nil, err
766 }
767
768 return &fantasy.ObjectResponse{
769 Object: obj,
770 RawText: jsonText,
771 Usage: usage,
772 FinishReason: finishReason,
773 Warnings: warnings,
774 }, nil
775}
776
// streamObjectWithJSONMode performs a streaming completion with a strict
// JSON-schema response format. As content deltas arrive it re-parses the
// accumulated text as partial JSON, and yields an object part whenever a
// new, schema-valid object materializes. The final part is a finish part
// (if at least one valid object was produced) or an error part.
func (o languageModel) streamObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
	jsonSchemaMap := schema.ToMap(call.Schema)

	// OpenAI strict mode requires additionalProperties: false on objects.
	addAdditionalPropertiesFalse(jsonSchemaMap)

	schemaName := call.SchemaName
	if schemaName == "" {
		schemaName = "response"
	}

	fantasyCall := fantasy.Call{
		Prompt:           call.Prompt,
		MaxOutputTokens:  call.MaxOutputTokens,
		Temperature:      call.Temperature,
		TopP:             call.TopP,
		PresencePenalty:  call.PresencePenalty,
		FrequencyPenalty: call.FrequencyPenalty,
		ProviderOptions:  call.ProviderOptions,
	}

	params, warnings, err := o.prepareParams(fantasyCall)
	if err != nil {
		return nil, err
	}

	params.ResponseFormat = openai.ChatCompletionNewParamsResponseFormatUnion{
		OfJSONSchema: &shared.ResponseFormatJSONSchemaParam{
			JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
				Name:        schemaName,
				Description: param.NewOpt(call.SchemaDescription),
				Schema:      jsonSchemaMap,
				Strict:      param.NewOpt(true),
			},
		},
	}

	// Request token usage on the final chunk of the stream.
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)

	return func(yield func(fantasy.ObjectStreamPart) bool) {
		// NOTE(review): warnings are emitted under the Object part type; if
		// the fantasy API has a dedicated warnings part type, this may be a
		// mislabel — confirm against fantasy.ObjectStreamPart.
		if len(warnings) > 0 {
			if !yield(fantasy.ObjectStreamPart{
				Type:     fantasy.ObjectStreamPartTypeObject,
				Warnings: warnings,
			}) {
				return
			}
		}

		var accumulated string       // full text received so far
		var lastParsedObject any     // last object yielded; used to dedupe
		var usage fantasy.Usage
		var finishReason fantasy.FinishReason
		var providerMetadata fantasy.ProviderMetadata
		var streamErr error

		for stream.Next() {
			chunk := stream.Current()

			// Update usage
			usage, providerMetadata = o.streamUsageFunc(chunk, make(map[string]any), providerMetadata)

			if len(chunk.Choices) == 0 {
				continue
			}

			// Only the first choice is inspected.
			choice := chunk.Choices[0]
			if choice.FinishReason != "" {
				finishReason = o.mapFinishReasonFunc(choice.FinishReason)
			}

			if choice.Delta.Content != "" {
				accumulated += choice.Delta.Content

				obj, state, parseErr := schema.ParsePartialJSON(accumulated)

				// Yield only when the partial parse yields a schema-valid
				// object that differs from the last one yielded.
				if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
					if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
						if !reflect.DeepEqual(obj, lastParsedObject) {
							if !yield(fantasy.ObjectStreamPart{
								Type:   fantasy.ObjectStreamPartTypeObject,
								Object: obj,
							}) {
								return
							}
							lastParsedObject = obj
						}
					}
				}

				// If parsing failed outright, give the caller-supplied
				// repair hook one chance per delta.
				if state == schema.ParseStateFailed && call.RepairText != nil {
					repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
					if repairErr == nil {
						obj2, state2, _ := schema.ParsePartialJSON(repairedText)
						if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
							schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
							if !reflect.DeepEqual(obj2, lastParsedObject) {
								if !yield(fantasy.ObjectStreamPart{
									Type:   fantasy.ObjectStreamPartTypeObject,
									Object: obj2,
								}) {
									return
								}
								lastParsedObject = obj2
							}
						}
					}
				}
			}
		}

		err := stream.Err()
		if err != nil && !errors.Is(err, io.EOF) {
			streamErr = toProviderErr(err)
			yield(fantasy.ObjectStreamPart{
				Type:  fantasy.ObjectStreamPartTypeError,
				Error: streamErr,
			})
			return
		}

		// A clean stream that never produced a valid object is reported as
		// a NoObjectGeneratedError.
		if lastParsedObject != nil {
			yield(fantasy.ObjectStreamPart{
				Type:             fantasy.ObjectStreamPartTypeFinish,
				Usage:            usage,
				FinishReason:     finishReason,
				ProviderMetadata: providerMetadata,
			})
		} else {
			yield(fantasy.ObjectStreamPart{
				Type: fantasy.ObjectStreamPartTypeError,
				Error: &fantasy.NoObjectGeneratedError{
					RawText:      accumulated,
					ParseError:   fmt.Errorf("no valid object generated in stream"),
					Usage:        usage,
					FinishReason: finishReason,
				},
			})
		}
	}, nil
}
921
// addAdditionalPropertiesFalse recursively adds "additionalProperties": false
// to all object schemas. This is required by OpenAI's strict mode for
// structured outputs. An existing additionalProperties value is preserved.
//
// Fixes: (1) the recursion previously missed subschemas under
// "$defs"/"definitions" and the "allOf"/"anyOf"/"oneOf" combinators, which
// strict mode also requires to carry additionalProperties: false; (2) the
// parameter was named "schema", shadowing the imported schema package —
// renamed to node (callers are unaffected by a parameter rename).
func addAdditionalPropertiesFalse(node map[string]any) {
	if node == nil {
		return
	}

	if node["type"] == "object" {
		if _, hasAdditional := node["additionalProperties"]; !hasAdditional {
			node["additionalProperties"] = false
		}

		// Recursively process nested properties.
		if properties, ok := node["properties"].(map[string]any); ok {
			for _, propValue := range properties {
				if propSchema, ok := propValue.(map[string]any); ok {
					addAdditionalPropertiesFalse(propSchema)
				}
			}
		}
	}

	// Handle array items.
	if items, ok := node["items"].(map[string]any); ok {
		addAdditionalPropertiesFalse(items)
	}

	// Handle shared definitions referenced via $ref.
	for _, key := range []string{"$defs", "definitions"} {
		if defs, ok := node[key].(map[string]any); ok {
			for _, defValue := range defs {
				if defSchema, ok := defValue.(map[string]any); ok {
					addAdditionalPropertiesFalse(defSchema)
				}
			}
		}
	}

	// Handle combinator branches.
	for _, key := range []string{"allOf", "anyOf", "oneOf"} {
		if branches, ok := node[key].([]any); ok {
			for _, branch := range branches {
				if branchSchema, ok := branch.(map[string]any); ok {
					addAdditionalPropertiesFalse(branchSchema)
				}
			}
		}
	}
}