1package openai
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "reflect"
10 "strings"
11
12 "charm.land/fantasy"
13 "charm.land/fantasy/object"
14 "charm.land/fantasy/schema"
15 xjson "github.com/charmbracelet/x/json"
16 "github.com/google/uuid"
17 "github.com/openai/openai-go/v2"
18 "github.com/openai/openai-go/v2/packages/param"
19 "github.com/openai/openai-go/v2/shared"
20)
21
// languageModel is the chat-completions implementation of
// fantasy.LanguageModel for OpenAI and OpenAI-compatible providers.
// Behavior that varies between compatible providers is injected through
// the function-valued fields; newLanguageModel installs defaults for them.
type languageModel struct {
	provider string // provider name returned by Provider()
	modelID string // model identifier returned by Model()
	client openai.Client
	objectMode fantasy.ObjectMode // strategy used by GenerateObject/StreamObject
	// Extension hooks; all are optional and defaulted in newLanguageModel
	// (except extraContentFunc and streamExtraFunc, which may be nil).
	prepareCallFunc LanguageModelPrepareCallFunc
	mapFinishReasonFunc LanguageModelMapFinishReasonFunc
	extraContentFunc LanguageModelExtraContentFunc
	usageFunc LanguageModelUsageFunc
	streamUsageFunc LanguageModelStreamUsageFunc
	streamExtraFunc LanguageModelStreamExtraFunc
	streamProviderMetadataFunc LanguageModelStreamProviderMetadataFunc
	toPromptFunc LanguageModelToPromptFunc
}
36
// LanguageModelOption is a function that configures a languageModel.
// Options are applied in order by newLanguageModel after defaults are set,
// so later options override earlier ones.
type LanguageModelOption = func(*languageModel)
39
40// WithLanguageModelPrepareCallFunc sets the prepare call function for the language model.
41func WithLanguageModelPrepareCallFunc(fn LanguageModelPrepareCallFunc) LanguageModelOption {
42 return func(l *languageModel) {
43 l.prepareCallFunc = fn
44 }
45}
46
47// WithLanguageModelMapFinishReasonFunc sets the map finish reason function for the language model.
48func WithLanguageModelMapFinishReasonFunc(fn LanguageModelMapFinishReasonFunc) LanguageModelOption {
49 return func(l *languageModel) {
50 l.mapFinishReasonFunc = fn
51 }
52}
53
54// WithLanguageModelExtraContentFunc sets the extra content function for the language model.
55func WithLanguageModelExtraContentFunc(fn LanguageModelExtraContentFunc) LanguageModelOption {
56 return func(l *languageModel) {
57 l.extraContentFunc = fn
58 }
59}
60
61// WithLanguageModelStreamExtraFunc sets the stream extra function for the language model.
62func WithLanguageModelStreamExtraFunc(fn LanguageModelStreamExtraFunc) LanguageModelOption {
63 return func(l *languageModel) {
64 l.streamExtraFunc = fn
65 }
66}
67
68// WithLanguageModelUsageFunc sets the usage function for the language model.
69func WithLanguageModelUsageFunc(fn LanguageModelUsageFunc) LanguageModelOption {
70 return func(l *languageModel) {
71 l.usageFunc = fn
72 }
73}
74
75// WithLanguageModelStreamUsageFunc sets the stream usage function for the language model.
76func WithLanguageModelStreamUsageFunc(fn LanguageModelStreamUsageFunc) LanguageModelOption {
77 return func(l *languageModel) {
78 l.streamUsageFunc = fn
79 }
80}
81
82// WithLanguageModelToPromptFunc sets the to prompt function for the language model.
83func WithLanguageModelToPromptFunc(fn LanguageModelToPromptFunc) LanguageModelOption {
84 return func(l *languageModel) {
85 l.toPromptFunc = fn
86 }
87}
88
89// WithLanguageModelObjectMode sets the object generation mode.
90func WithLanguageModelObjectMode(om fantasy.ObjectMode) LanguageModelOption {
91 return func(l *languageModel) {
92 // not supported
93 if om == fantasy.ObjectModeJSON {
94 om = fantasy.ObjectModeAuto
95 }
96 l.objectMode = om
97 }
98}
99
100func newLanguageModel(modelID string, provider string, client openai.Client, opts ...LanguageModelOption) languageModel {
101 model := languageModel{
102 modelID: modelID,
103 provider: provider,
104 client: client,
105 objectMode: fantasy.ObjectModeAuto,
106 prepareCallFunc: DefaultPrepareCallFunc,
107 mapFinishReasonFunc: DefaultMapFinishReasonFunc,
108 usageFunc: DefaultUsageFunc,
109 streamUsageFunc: DefaultStreamUsageFunc,
110 streamProviderMetadataFunc: DefaultStreamProviderMetadataFunc,
111 toPromptFunc: DefaultToPrompt,
112 }
113
114 for _, o := range opts {
115 o(&model)
116 }
117 return model
118}
119
// streamToolCall tracks the in-flight state of one tool call while its
// argument deltas are being streamed.
type streamToolCall struct {
	id string
	name string
	arguments string // accumulated JSON argument text so far
	hasFinished bool // true once a complete tool-call part has been emitted
}
126
// Model implements fantasy.LanguageModel. It returns the model identifier
// this instance was constructed with.
func (o languageModel) Model() string {
	return o.modelID
}
131
// Provider implements fantasy.LanguageModel. It returns the provider name
// this instance was constructed with.
func (o languageModel) Provider() string {
	return o.provider
}
136
137func (o languageModel) prepareParams(call fantasy.Call) (*openai.ChatCompletionNewParams, []fantasy.CallWarning, error) {
138 params := &openai.ChatCompletionNewParams{}
139 messages, warnings := o.toPromptFunc(call.Prompt, o.provider, o.modelID)
140 if call.TopK != nil {
141 warnings = append(warnings, fantasy.CallWarning{
142 Type: fantasy.CallWarningTypeUnsupportedSetting,
143 Setting: "top_k",
144 })
145 }
146
147 if call.MaxOutputTokens != nil {
148 params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
149 }
150 if call.Temperature != nil {
151 params.Temperature = param.NewOpt(*call.Temperature)
152 }
153 if call.TopP != nil {
154 params.TopP = param.NewOpt(*call.TopP)
155 }
156 if call.FrequencyPenalty != nil {
157 params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
158 }
159 if call.PresencePenalty != nil {
160 params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
161 }
162
163 if isReasoningModel(o.modelID) {
164 // remove unsupported settings for reasoning models
165 // see https://platform.openai.com/docs/guides/reasoning#limitations
166 if call.Temperature != nil {
167 params.Temperature = param.Opt[float64]{}
168 warnings = append(warnings, fantasy.CallWarning{
169 Type: fantasy.CallWarningTypeUnsupportedSetting,
170 Setting: "temperature",
171 Details: "temperature is not supported for reasoning models",
172 })
173 }
174 if call.TopP != nil {
175 params.TopP = param.Opt[float64]{}
176 warnings = append(warnings, fantasy.CallWarning{
177 Type: fantasy.CallWarningTypeUnsupportedSetting,
178 Setting: "TopP",
179 Details: "TopP is not supported for reasoning models",
180 })
181 }
182 if call.FrequencyPenalty != nil {
183 params.FrequencyPenalty = param.Opt[float64]{}
184 warnings = append(warnings, fantasy.CallWarning{
185 Type: fantasy.CallWarningTypeUnsupportedSetting,
186 Setting: "FrequencyPenalty",
187 Details: "FrequencyPenalty is not supported for reasoning models",
188 })
189 }
190 if call.PresencePenalty != nil {
191 params.PresencePenalty = param.Opt[float64]{}
192 warnings = append(warnings, fantasy.CallWarning{
193 Type: fantasy.CallWarningTypeUnsupportedSetting,
194 Setting: "PresencePenalty",
195 Details: "PresencePenalty is not supported for reasoning models",
196 })
197 }
198
199 // reasoning models use max_completion_tokens instead of max_tokens
200 if call.MaxOutputTokens != nil {
201 if !params.MaxCompletionTokens.Valid() {
202 params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
203 }
204 params.MaxTokens = param.Opt[int64]{}
205 }
206 }
207
208 // Handle search preview models
209 if isSearchPreviewModel(o.modelID) {
210 if call.Temperature != nil {
211 params.Temperature = param.Opt[float64]{}
212 warnings = append(warnings, fantasy.CallWarning{
213 Type: fantasy.CallWarningTypeUnsupportedSetting,
214 Setting: "temperature",
215 Details: "temperature is not supported for the search preview models and has been removed.",
216 })
217 }
218 }
219
220 optionsWarnings, err := o.prepareCallFunc(o, params, call)
221 if err != nil {
222 return nil, nil, err
223 }
224
225 if len(optionsWarnings) > 0 {
226 warnings = append(warnings, optionsWarnings...)
227 }
228
229 params.Messages = messages
230 params.Model = o.modelID
231
232 if len(call.Tools) > 0 {
233 tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
234 params.Tools = tools
235 if toolChoice != nil {
236 params.ToolChoice = *toolChoice
237 }
238 warnings = append(warnings, toolWarnings...)
239 }
240 return params, warnings, nil
241}
242
243// Generate implements fantasy.LanguageModel.
244func (o languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
245 params, warnings, err := o.prepareParams(call)
246 if err != nil {
247 return nil, err
248 }
249 response, err := o.client.Chat.Completions.New(ctx, *params, callUARequestOptions(call)...)
250 if err != nil {
251 return nil, toProviderErr(err)
252 }
253
254 if len(response.Choices) == 0 {
255 return nil, &fantasy.Error{Title: "no response", Message: "no response generated"}
256 }
257 choice := response.Choices[0]
258 content := make([]fantasy.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
259 text := choice.Message.Content
260 if text != "" {
261 content = append(content, fantasy.TextContent{
262 Text: text,
263 })
264 }
265 if o.extraContentFunc != nil {
266 extraContent := o.extraContentFunc(choice)
267 content = append(content, extraContent...)
268 }
269 for _, tc := range choice.Message.ToolCalls {
270 toolCallID := tc.ID
271 content = append(content, fantasy.ToolCallContent{
272 ProviderExecuted: false,
273 ToolCallID: toolCallID,
274 ToolName: tc.Function.Name,
275 Input: tc.Function.Arguments,
276 })
277 }
278 for _, annotation := range choice.Message.Annotations {
279 if annotation.Type == "url_citation" {
280 content = append(content, fantasy.SourceContent{
281 SourceType: fantasy.SourceTypeURL,
282 ID: uuid.NewString(),
283 URL: annotation.URLCitation.URL,
284 Title: annotation.URLCitation.Title,
285 })
286 }
287 }
288
289 usage, providerMetadata := o.usageFunc(*response)
290
291 mappedFinishReason := o.mapFinishReasonFunc(choice.FinishReason)
292 if len(choice.Message.ToolCalls) > 0 {
293 mappedFinishReason = fantasy.FinishReasonToolCalls
294 }
295 return &fantasy.Response{
296 Content: content,
297 Usage: usage,
298 FinishReason: mappedFinishReason,
299 ProviderMetadata: fantasy.ProviderMetadata{
300 Name: providerMetadata,
301 },
302 Warnings: warnings,
303 }, nil
304}
305
306// Stream implements fantasy.LanguageModel.
307func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
308 params, warnings, err := o.prepareParams(call)
309 if err != nil {
310 return nil, err
311 }
312
313 params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
314 IncludeUsage: openai.Bool(true),
315 }
316
317 stream := o.client.Chat.Completions.NewStreaming(ctx, *params, callUARequestOptions(call)...)
318 isActiveText := false
319 toolCalls := make(map[int64]streamToolCall)
320
321 providerMetadata := fantasy.ProviderMetadata{
322 Name: &ProviderMetadata{},
323 }
324 acc := openai.ChatCompletionAccumulator{}
325 extraContext := make(map[string]any)
326 var usage fantasy.Usage
327 var finishReason string
328 return func(yield func(fantasy.StreamPart) bool) {
329 if len(warnings) > 0 {
330 if !yield(fantasy.StreamPart{
331 Type: fantasy.StreamPartTypeWarnings,
332 Warnings: warnings,
333 }) {
334 return
335 }
336 }
337 for stream.Next() {
338 chunk := stream.Current()
339 acc.AddChunk(chunk)
340 usage, providerMetadata = o.streamUsageFunc(chunk, extraContext, providerMetadata)
341 if len(chunk.Choices) == 0 {
342 continue
343 }
344 for _, choice := range chunk.Choices {
345 if choice.FinishReason != "" {
346 finishReason = choice.FinishReason
347 }
348 switch {
349 case choice.Delta.Content != "":
350 if !isActiveText {
351 isActiveText = true
352 if !yield(fantasy.StreamPart{
353 Type: fantasy.StreamPartTypeTextStart,
354 ID: "0",
355 }) {
356 return
357 }
358 }
359 if !yield(fantasy.StreamPart{
360 Type: fantasy.StreamPartTypeTextDelta,
361 ID: "0",
362 Delta: choice.Delta.Content,
363 }) {
364 return
365 }
366 case len(choice.Delta.ToolCalls) > 0:
367 if isActiveText {
368 isActiveText = false
369 if !yield(fantasy.StreamPart{
370 Type: fantasy.StreamPartTypeTextEnd,
371 ID: "0",
372 }) {
373 return
374 }
375 }
376
377 for _, toolCallDelta := range choice.Delta.ToolCalls {
378 if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
379 if existingToolCall.hasFinished {
380 continue
381 }
382 if toolCallDelta.Function.Arguments != "" {
383 existingToolCall.arguments += toolCallDelta.Function.Arguments
384 }
385 if !yield(fantasy.StreamPart{
386 Type: fantasy.StreamPartTypeToolInputDelta,
387 ID: existingToolCall.id,
388 Delta: toolCallDelta.Function.Arguments,
389 }) {
390 return
391 }
392 toolCalls[toolCallDelta.Index] = existingToolCall
393 if xjson.IsValid(existingToolCall.arguments) {
394 if !yield(fantasy.StreamPart{
395 Type: fantasy.StreamPartTypeToolInputEnd,
396 ID: existingToolCall.id,
397 }) {
398 return
399 }
400
401 if !yield(fantasy.StreamPart{
402 Type: fantasy.StreamPartTypeToolCall,
403 ID: existingToolCall.id,
404 ToolCallName: existingToolCall.name,
405 ToolCallInput: existingToolCall.arguments,
406 }) {
407 return
408 }
409 existingToolCall.hasFinished = true
410 toolCalls[toolCallDelta.Index] = existingToolCall
411 }
412 } else {
413 var err error
414 if toolCallDelta.Type != "function" {
415 err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function' type."}
416 }
417 if toolCallDelta.ID == "" {
418 err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'id' to be a string."}
419 }
420 if toolCallDelta.Function.Name == "" {
421 err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function.name' to be a string."}
422 }
423 if err != nil {
424 yield(fantasy.StreamPart{
425 Type: fantasy.StreamPartTypeError,
426 Error: toProviderErr(stream.Err()),
427 })
428 return
429 }
430
431 if !yield(fantasy.StreamPart{
432 Type: fantasy.StreamPartTypeToolInputStart,
433 ID: toolCallDelta.ID,
434 ToolCallName: toolCallDelta.Function.Name,
435 }) {
436 return
437 }
438 toolCalls[toolCallDelta.Index] = streamToolCall{
439 id: toolCallDelta.ID,
440 name: toolCallDelta.Function.Name,
441 arguments: toolCallDelta.Function.Arguments,
442 }
443
444 exTc := toolCalls[toolCallDelta.Index]
445 if exTc.arguments != "" {
446 if !yield(fantasy.StreamPart{
447 Type: fantasy.StreamPartTypeToolInputDelta,
448 ID: exTc.id,
449 Delta: exTc.arguments,
450 }) {
451 return
452 }
453 if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
454 if !yield(fantasy.StreamPart{
455 Type: fantasy.StreamPartTypeToolInputEnd,
456 ID: toolCallDelta.ID,
457 }) {
458 return
459 }
460
461 if !yield(fantasy.StreamPart{
462 Type: fantasy.StreamPartTypeToolCall,
463 ID: exTc.id,
464 ToolCallName: exTc.name,
465 ToolCallInput: exTc.arguments,
466 }) {
467 return
468 }
469 exTc.hasFinished = true
470 toolCalls[toolCallDelta.Index] = exTc
471 }
472 }
473 continue
474 }
475 }
476 }
477
478 if o.streamExtraFunc != nil {
479 updatedContext, shouldContinue := o.streamExtraFunc(chunk, yield, extraContext)
480 if !shouldContinue {
481 return
482 }
483 extraContext = updatedContext
484 }
485 }
486
487 for _, choice := range chunk.Choices {
488 if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
489 for _, annotation := range annotations {
490 if annotation.Type == "url_citation" {
491 if !yield(fantasy.StreamPart{
492 Type: fantasy.StreamPartTypeSource,
493 ID: uuid.NewString(),
494 SourceType: fantasy.SourceTypeURL,
495 URL: annotation.URLCitation.URL,
496 Title: annotation.URLCitation.Title,
497 }) {
498 return
499 }
500 }
501 }
502 }
503 }
504 }
505 err := stream.Err()
506 if err == nil || errors.Is(err, io.EOF) {
507 if isActiveText {
508 isActiveText = false
509 if !yield(fantasy.StreamPart{
510 Type: fantasy.StreamPartTypeTextEnd,
511 ID: "0",
512 }) {
513 return
514 }
515 }
516
517 if len(acc.Choices) > 0 {
518 choice := acc.Choices[0]
519 providerMetadata = o.streamProviderMetadataFunc(choice, providerMetadata)
520
521 for _, annotation := range choice.Message.Annotations {
522 if annotation.Type == "url_citation" {
523 if !yield(fantasy.StreamPart{
524 Type: fantasy.StreamPartTypeSource,
525 ID: acc.ID,
526 SourceType: fantasy.SourceTypeURL,
527 URL: annotation.URLCitation.URL,
528 Title: annotation.URLCitation.Title,
529 }) {
530 return
531 }
532 }
533 }
534 }
535 mappedFinishReason := o.mapFinishReasonFunc(finishReason)
536 if len(acc.Choices) > 0 {
537 choice := acc.Choices[0]
538 if len(choice.Message.ToolCalls) > 0 {
539 mappedFinishReason = fantasy.FinishReasonToolCalls
540 }
541 }
542 yield(fantasy.StreamPart{
543 Type: fantasy.StreamPartTypeFinish,
544 Usage: usage,
545 FinishReason: mappedFinishReason,
546 ProviderMetadata: providerMetadata,
547 })
548 return
549 } else { //nolint: revive
550 yield(fantasy.StreamPart{
551 Type: fantasy.StreamPartTypeError,
552 Error: toProviderErr(err),
553 })
554 return
555 }
556 }, nil
557}
558
// isReasoningModel reports whether the model ID belongs to a reasoning
// model family (o1/o3/o4, oss, gpt-5). These models reject sampling
// settings like temperature and use max_completion_tokens.
//
// Note: the previous extra check for "gpt-5-chat" was dead code — any ID
// containing "gpt-5-chat" already contains "gpt-5" — so it has been removed
// with no behavior change.
func isReasoningModel(modelID string) bool {
	return strings.HasPrefix(modelID, "o1") || strings.Contains(modelID, "-o1") ||
		strings.HasPrefix(modelID, "o3") || strings.Contains(modelID, "-o3") ||
		strings.HasPrefix(modelID, "o4") || strings.Contains(modelID, "-o4") ||
		strings.HasPrefix(modelID, "oss") || strings.Contains(modelID, "-oss") ||
		strings.Contains(modelID, "gpt-5")
}
566
// isSearchPreviewModel reports whether the model is a "search preview"
// variant (ID contains "search-preview"); these reject the temperature
// setting (see prepareParams).
func isSearchPreviewModel(modelID string) bool {
	return strings.Contains(modelID, "search-preview")
}
570
// supportsFlexProcessing reports whether the model ID is eligible for the
// flex service tier (o3 family, o4-mini, and gpt-5 models).
func supportsFlexProcessing(modelID string) bool {
	switch {
	case strings.HasPrefix(modelID, "o3"), strings.Contains(modelID, "-o3"):
		return true
	case strings.Contains(modelID, "o4-mini"):
		return true
	case strings.Contains(modelID, "gpt-5"):
		return true
	default:
		return false
	}
}
575
// supportsPriorityProcessing reports whether the model ID is eligible for
// the priority service tier (gpt-4 and gpt-5 families, o3 family, o4-mini).
//
// Note: the previous extra check for "gpt-5-mini" was dead code — any ID
// containing "gpt-5-mini" already contains "gpt-5" — so it has been removed
// with no behavior change.
func supportsPriorityProcessing(modelID string) bool {
	return strings.Contains(modelID, "gpt-4") || strings.Contains(modelID, "gpt-5") ||
		strings.HasPrefix(modelID, "o3") || strings.Contains(modelID, "-o3") ||
		strings.Contains(modelID, "o4-mini")
}
581
582func toOpenAiTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []fantasy.CallWarning) {
583 for _, tool := range tools {
584 if tool.GetType() == fantasy.ToolTypeFunction {
585 ft, ok := tool.(fantasy.FunctionTool)
586 if !ok {
587 continue
588 }
589 openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
590 OfFunction: &openai.ChatCompletionFunctionToolParam{
591 Function: shared.FunctionDefinitionParam{
592 Name: ft.Name,
593 Description: param.NewOpt(ft.Description),
594 Parameters: openai.FunctionParameters(ft.InputSchema),
595 Strict: param.NewOpt(false),
596 },
597 Type: "function",
598 },
599 })
600 continue
601 }
602
603 warnings = append(warnings, fantasy.CallWarning{
604 Type: fantasy.CallWarningTypeUnsupportedTool,
605 Tool: tool,
606 Message: "tool is not supported",
607 })
608 }
609 if toolChoice == nil {
610 return openAiTools, openAiToolChoice, warnings
611 }
612
613 switch *toolChoice {
614 case fantasy.ToolChoiceAuto:
615 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
616 OfAuto: param.NewOpt("auto"),
617 }
618 case fantasy.ToolChoiceNone:
619 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
620 OfAuto: param.NewOpt("none"),
621 }
622 case fantasy.ToolChoiceRequired:
623 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
624 OfAuto: param.NewOpt("required"),
625 }
626 default:
627 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
628 OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
629 Type: "function",
630 Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
631 Name: string(*toolChoice),
632 },
633 },
634 }
635 }
636 return openAiTools, openAiToolChoice, warnings
637}
638
639// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
640func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
641 var annotations []openai.ChatCompletionMessageAnnotation
642
643 // Parse the raw JSON to extract annotations
644 var deltaData map[string]any
645 if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
646 return annotations
647 }
648
649 // Check if annotations exist in the delta
650 if annotationsData, ok := deltaData["annotations"].([]any); ok {
651 for _, annotationData := range annotationsData {
652 if annotationMap, ok := annotationData.(map[string]any); ok {
653 if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
654 if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
655 url, urlOk := urlCitationData["url"].(string)
656 title, titleOk := urlCitationData["title"].(string)
657 if urlOk && titleOk {
658 annotation := openai.ChatCompletionMessageAnnotation{
659 Type: "url_citation",
660 URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
661 URL: url,
662 Title: title,
663 },
664 }
665 annotations = append(annotations, annotation)
666 }
667 }
668 }
669 }
670 }
671 }
672
673 return annotations
674}
675
676// GenerateObject implements fantasy.LanguageModel.
677func (o languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
678 switch o.objectMode {
679 case fantasy.ObjectModeText:
680 return object.GenerateWithText(ctx, o, call)
681 case fantasy.ObjectModeTool:
682 return object.GenerateWithTool(ctx, o, call)
683 default:
684 return o.generateObjectWithJSONMode(ctx, call)
685 }
686}
687
688// StreamObject implements fantasy.LanguageModel.
689func (o languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
690 switch o.objectMode {
691 case fantasy.ObjectModeTool:
692 return object.StreamWithTool(ctx, o, call)
693 case fantasy.ObjectModeText:
694 return object.StreamWithText(ctx, o, call)
695 default:
696 return o.streamObjectWithJSONMode(ctx, call)
697 }
698}
699
700func (o languageModel) generateObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
701 jsonSchemaMap := schema.ToMap(call.Schema)
702
703 addAdditionalPropertiesFalse(jsonSchemaMap)
704
705 schemaName := call.SchemaName
706 if schemaName == "" {
707 schemaName = "response"
708 }
709
710 fantasyCall := fantasy.Call{
711 Prompt: call.Prompt,
712 MaxOutputTokens: call.MaxOutputTokens,
713 Temperature: call.Temperature,
714 TopP: call.TopP,
715 PresencePenalty: call.PresencePenalty,
716 FrequencyPenalty: call.FrequencyPenalty,
717 ProviderOptions: call.ProviderOptions,
718 }
719
720 params, warnings, err := o.prepareParams(fantasyCall)
721 if err != nil {
722 return nil, err
723 }
724
725 params.ResponseFormat = openai.ChatCompletionNewParamsResponseFormatUnion{
726 OfJSONSchema: &shared.ResponseFormatJSONSchemaParam{
727 JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
728 Name: schemaName,
729 Description: param.NewOpt(call.SchemaDescription),
730 Schema: jsonSchemaMap,
731 Strict: param.NewOpt(true),
732 },
733 },
734 }
735
736 response, err := o.client.Chat.Completions.New(ctx, *params, objectCallUARequestOptions(call)...)
737 if err != nil {
738 return nil, toProviderErr(err)
739 }
740 if len(response.Choices) == 0 {
741 usage, _ := o.usageFunc(*response)
742 return nil, &fantasy.NoObjectGeneratedError{
743 RawText: "",
744 ParseError: fmt.Errorf("no choices in response"),
745 Usage: usage,
746 FinishReason: fantasy.FinishReasonUnknown,
747 }
748 }
749
750 choice := response.Choices[0]
751 jsonText := choice.Message.Content
752
753 var obj any
754 if call.RepairText != nil {
755 obj, err = schema.ParseAndValidateWithRepair(ctx, jsonText, call.Schema, call.RepairText)
756 } else {
757 obj, err = schema.ParseAndValidate(jsonText, call.Schema)
758 }
759
760 usage, _ := o.usageFunc(*response)
761 finishReason := o.mapFinishReasonFunc(choice.FinishReason)
762
763 if err != nil {
764 if nogErr, ok := err.(*fantasy.NoObjectGeneratedError); ok {
765 nogErr.Usage = usage
766 nogErr.FinishReason = finishReason
767 }
768 return nil, err
769 }
770
771 return &fantasy.ObjectResponse{
772 Object: obj,
773 RawText: jsonText,
774 Usage: usage,
775 FinishReason: finishReason,
776 Warnings: warnings,
777 }, nil
778}
779
// streamObjectWithJSONMode streams a structured object using OpenAI's
// strict JSON-schema response format. Text deltas are accumulated and the
// partial JSON is re-parsed on every delta; each time a new, schema-valid
// object state appears it is yielded as an object part. The final part is
// either a finish part (if at least one valid object was produced) or a
// NoObjectGeneratedError.
func (o languageModel) streamObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
	jsonSchemaMap := schema.ToMap(call.Schema)

	// Strict structured outputs require additionalProperties:false on objects.
	addAdditionalPropertiesFalse(jsonSchemaMap)

	schemaName := call.SchemaName
	if schemaName == "" {
		schemaName = "response"
	}

	// Reuse the regular call-preparation path for shared settings.
	fantasyCall := fantasy.Call{
		Prompt:           call.Prompt,
		MaxOutputTokens:  call.MaxOutputTokens,
		Temperature:      call.Temperature,
		TopP:             call.TopP,
		PresencePenalty:  call.PresencePenalty,
		FrequencyPenalty: call.FrequencyPenalty,
		ProviderOptions:  call.ProviderOptions,
	}

	params, warnings, err := o.prepareParams(fantasyCall)
	if err != nil {
		return nil, err
	}

	params.ResponseFormat = openai.ChatCompletionNewParamsResponseFormatUnion{
		OfJSONSchema: &shared.ResponseFormatJSONSchemaParam{
			JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
				Name:        schemaName,
				Description: param.NewOpt(call.SchemaDescription),
				Schema:      jsonSchemaMap,
				Strict:      param.NewOpt(true),
			},
		},
	}

	// Ask the API to attach token usage to the final stream chunk.
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params, objectCallUARequestOptions(call)...)

	return func(yield func(fantasy.ObjectStreamPart) bool) {
		// NOTE(review): warnings are yielded with the Object part type —
		// confirm whether a dedicated warnings part type exists for object
		// streams.
		if len(warnings) > 0 {
			if !yield(fantasy.ObjectStreamPart{
				Type:     fantasy.ObjectStreamPartTypeObject,
				Warnings: warnings,
			}) {
				return
			}
		}

		var accumulated string
		var lastParsedObject any
		var usage fantasy.Usage
		var finishReason fantasy.FinishReason
		var providerMetadata fantasy.ProviderMetadata
		var streamErr error

		for stream.Next() {
			chunk := stream.Current()

			// Update usage.
			// NOTE(review): a fresh extra-context map is passed per chunk,
			// unlike Stream which reuses one across chunks — confirm this is
			// intentional for the usage hook.
			usage, providerMetadata = o.streamUsageFunc(chunk, make(map[string]any), providerMetadata)

			if len(chunk.Choices) == 0 {
				continue
			}

			choice := chunk.Choices[0]
			if choice.FinishReason != "" {
				finishReason = o.mapFinishReasonFunc(choice.FinishReason)
			}

			if choice.Delta.Content != "" {
				accumulated += choice.Delta.Content

				// Re-parse the accumulated partial JSON on every delta.
				obj, state, parseErr := schema.ParsePartialJSON(accumulated)

				if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
					if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
						// Only yield when the object actually changed.
						if !reflect.DeepEqual(obj, lastParsedObject) {
							if !yield(fantasy.ObjectStreamPart{
								Type:   fantasy.ObjectStreamPartTypeObject,
								Object: obj,
							}) {
								return
							}
							lastParsedObject = obj
						}
					}
				}

				// On a hard parse failure, give the caller's repair function
				// a chance to fix the accumulated text.
				if state == schema.ParseStateFailed && call.RepairText != nil {
					repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
					if repairErr == nil {
						obj2, state2, _ := schema.ParsePartialJSON(repairedText)
						if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
							schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
							if !reflect.DeepEqual(obj2, lastParsedObject) {
								if !yield(fantasy.ObjectStreamPart{
									Type:   fantasy.ObjectStreamPartTypeObject,
									Object: obj2,
								}) {
									return
								}
								lastParsedObject = obj2
							}
						}
					}
				}
			}
		}

		err := stream.Err()
		if err != nil && !errors.Is(err, io.EOF) {
			streamErr = toProviderErr(err)
			yield(fantasy.ObjectStreamPart{
				Type:  fantasy.ObjectStreamPartTypeError,
				Error: streamErr,
			})
			return
		}

		// Finish with the usual part if we ever produced a valid object,
		// otherwise report a no-object error with the raw accumulated text.
		if lastParsedObject != nil {
			yield(fantasy.ObjectStreamPart{
				Type:             fantasy.ObjectStreamPartTypeFinish,
				Usage:            usage,
				FinishReason:     finishReason,
				ProviderMetadata: providerMetadata,
			})
		} else {
			yield(fantasy.ObjectStreamPart{
				Type: fantasy.ObjectStreamPartTypeError,
				Error: &fantasy.NoObjectGeneratedError{
					RawText:      accumulated,
					ParseError:   fmt.Errorf("no valid object generated in stream"),
					Usage:        usage,
					FinishReason: finishReason,
				},
			})
		}
	}, nil
}
924
// addAdditionalPropertiesFalse recursively adds "additionalProperties": false
// to all object schemas, as required by OpenAI's strict mode for structured
// outputs. Existing explicit values are never overwritten.
//
// In addition to "properties" and a single "items" schema, this now also
// descends into tuple-form "items" arrays, the anyOf/oneOf/allOf combinators,
// and $defs/definitions, so nested object schemas are covered as well.
func addAdditionalPropertiesFalse(schema map[string]any) {
	if schema["type"] == "object" {
		if _, hasAdditional := schema["additionalProperties"]; !hasAdditional {
			schema["additionalProperties"] = false
		}

		// Recursively process nested properties.
		if properties, ok := schema["properties"].(map[string]any); ok {
			for _, propValue := range properties {
				if propSchema, ok := propValue.(map[string]any); ok {
					addAdditionalPropertiesFalse(propSchema)
				}
			}
		}
	}

	// Handle array items: single schema or tuple form.
	switch items := schema["items"].(type) {
	case map[string]any:
		addAdditionalPropertiesFalse(items)
	case []any:
		for _, item := range items {
			if sub, ok := item.(map[string]any); ok {
				addAdditionalPropertiesFalse(sub)
			}
		}
	}

	// Schema combinators hold lists of subschemas.
	for _, key := range []string{"anyOf", "oneOf", "allOf"} {
		if subs, ok := schema[key].([]any); ok {
			for _, s := range subs {
				if sub, ok := s.(map[string]any); ok {
					addAdditionalPropertiesFalse(sub)
				}
			}
		}
	}

	// Shared definitions referenced via $ref.
	for _, key := range []string{"$defs", "definitions"} {
		if defs, ok := schema[key].(map[string]any); ok {
			for _, def := range defs {
				if sub, ok := def.(map[string]any); ok {
					addAdditionalPropertiesFalse(sub)
				}
			}
		}
	}
}