1package openai
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "reflect"
11 "strings"
12
13 "charm.land/fantasy"
14 "charm.land/fantasy/object"
15 "charm.land/fantasy/schema"
16 "github.com/charmbracelet/openai-go"
17 "github.com/charmbracelet/openai-go/packages/param"
18 "github.com/charmbracelet/openai-go/shared"
19 xjson "github.com/charmbracelet/x/json"
20 "github.com/google/uuid"
21)
22
// languageModel implements fantasy.LanguageModel on top of the OpenAI
// chat-completions API. The *Func fields are customization hooks so other
// OpenAI-compatible providers can reuse this implementation while overriding
// prompt conversion, usage accounting, and response mapping.
type languageModel struct {
	provider                   string                                  // provider name returned by Provider()
	modelID                    string                                  // model identifier returned by Model()
	client                     openai.Client                           // underlying OpenAI API client
	objectMode                 fantasy.ObjectMode                      // strategy used by GenerateObject/StreamObject
	prepareCallFunc            LanguageModelPrepareCallFunc            // last-chance mutation of request params per call
	mapFinishReasonFunc        LanguageModelMapFinishReasonFunc        // maps provider finish reason strings to fantasy values
	extraContentFunc           LanguageModelExtraContentFunc           // optional extra content derived from a choice
	usageFunc                  LanguageModelUsageFunc                  // extracts usage/metadata from a full response
	streamUsageFunc            LanguageModelStreamUsageFunc            // extracts usage/metadata from stream chunks
	streamExtraFunc            LanguageModelStreamExtraFunc            // optional per-chunk streaming hook
	streamProviderMetadataFunc LanguageModelStreamProviderMetadataFunc // final metadata extraction from accumulated stream
	toPromptFunc               LanguageModelToPromptFunc               // converts a fantasy prompt into OpenAI messages
}
37
// LanguageModelOption is a function that configures a languageModel.
// Options are applied in order by newLanguageModel, after defaults are set.
type LanguageModelOption = func(*languageModel)
40
41// WithLanguageModelPrepareCallFunc sets the prepare call function for the language model.
42func WithLanguageModelPrepareCallFunc(fn LanguageModelPrepareCallFunc) LanguageModelOption {
43 return func(l *languageModel) {
44 l.prepareCallFunc = fn
45 }
46}
47
48// WithLanguageModelMapFinishReasonFunc sets the map finish reason function for the language model.
49func WithLanguageModelMapFinishReasonFunc(fn LanguageModelMapFinishReasonFunc) LanguageModelOption {
50 return func(l *languageModel) {
51 l.mapFinishReasonFunc = fn
52 }
53}
54
55// WithLanguageModelExtraContentFunc sets the extra content function for the language model.
56func WithLanguageModelExtraContentFunc(fn LanguageModelExtraContentFunc) LanguageModelOption {
57 return func(l *languageModel) {
58 l.extraContentFunc = fn
59 }
60}
61
62// WithLanguageModelStreamExtraFunc sets the stream extra function for the language model.
63func WithLanguageModelStreamExtraFunc(fn LanguageModelStreamExtraFunc) LanguageModelOption {
64 return func(l *languageModel) {
65 l.streamExtraFunc = fn
66 }
67}
68
69// WithLanguageModelUsageFunc sets the usage function for the language model.
70func WithLanguageModelUsageFunc(fn LanguageModelUsageFunc) LanguageModelOption {
71 return func(l *languageModel) {
72 l.usageFunc = fn
73 }
74}
75
76// WithLanguageModelStreamUsageFunc sets the stream usage function for the language model.
77func WithLanguageModelStreamUsageFunc(fn LanguageModelStreamUsageFunc) LanguageModelOption {
78 return func(l *languageModel) {
79 l.streamUsageFunc = fn
80 }
81}
82
83// WithLanguageModelToPromptFunc sets the to prompt function for the language model.
84func WithLanguageModelToPromptFunc(fn LanguageModelToPromptFunc) LanguageModelOption {
85 return func(l *languageModel) {
86 l.toPromptFunc = fn
87 }
88}
89
90// WithLanguageModelObjectMode sets the object generation mode.
91func WithLanguageModelObjectMode(om fantasy.ObjectMode) LanguageModelOption {
92 return func(l *languageModel) {
93 // not supported
94 if om == fantasy.ObjectModeJSON {
95 om = fantasy.ObjectModeAuto
96 }
97 l.objectMode = om
98 }
99}
100
101func newLanguageModel(modelID string, provider string, client openai.Client, opts ...LanguageModelOption) languageModel {
102 model := languageModel{
103 modelID: modelID,
104 provider: provider,
105 client: client,
106 objectMode: fantasy.ObjectModeAuto,
107 prepareCallFunc: DefaultPrepareCallFunc,
108 mapFinishReasonFunc: DefaultMapFinishReasonFunc,
109 usageFunc: DefaultUsageFunc,
110 streamUsageFunc: DefaultStreamUsageFunc,
111 streamProviderMetadataFunc: DefaultStreamProviderMetadataFunc,
112 toPromptFunc: DefaultToPrompt,
113 }
114
115 for _, o := range opts {
116 o(&model)
117 }
118 return model
119}
120
// streamToolCall accumulates a single tool call across streaming deltas,
// keyed by the delta's tool-call index.
type streamToolCall struct {
	id          string // tool call ID from the provider (or synthesized)
	name        string // function name of the tool being invoked
	arguments   string // JSON arguments, concatenated from streamed fragments
	hasFinished bool   // true once the completed tool call has been emitted
}
127
// Model implements fantasy.LanguageModel. It returns the configured model ID.
func (o languageModel) Model() string {
	return o.modelID
}
132
// Provider implements fantasy.LanguageModel. It returns the provider name.
func (o languageModel) Provider() string {
	return o.provider
}
137
138func (o languageModel) prepareParams(call fantasy.Call) (*openai.ChatCompletionNewParams, []fantasy.CallWarning, error) {
139 params := &openai.ChatCompletionNewParams{}
140 messages, warnings := o.toPromptFunc(call.Prompt, o.provider, o.modelID)
141 if call.TopK != nil {
142 warnings = append(warnings, fantasy.CallWarning{
143 Type: fantasy.CallWarningTypeUnsupportedSetting,
144 Setting: "top_k",
145 })
146 }
147
148 if call.MaxOutputTokens != nil {
149 params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
150 }
151 if call.Temperature != nil {
152 params.Temperature = param.NewOpt(*call.Temperature)
153 }
154 if call.TopP != nil {
155 params.TopP = param.NewOpt(*call.TopP)
156 }
157 if call.FrequencyPenalty != nil {
158 params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
159 }
160 if call.PresencePenalty != nil {
161 params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
162 }
163
164 if isReasoningModel(o.modelID) {
165 // remove unsupported settings for reasoning models
166 // see https://platform.openai.com/docs/guides/reasoning#limitations
167 if call.Temperature != nil {
168 params.Temperature = param.Opt[float64]{}
169 warnings = append(warnings, fantasy.CallWarning{
170 Type: fantasy.CallWarningTypeUnsupportedSetting,
171 Setting: "temperature",
172 Details: "temperature is not supported for reasoning models",
173 })
174 }
175 if call.TopP != nil {
176 params.TopP = param.Opt[float64]{}
177 warnings = append(warnings, fantasy.CallWarning{
178 Type: fantasy.CallWarningTypeUnsupportedSetting,
179 Setting: "TopP",
180 Details: "TopP is not supported for reasoning models",
181 })
182 }
183 if call.FrequencyPenalty != nil {
184 params.FrequencyPenalty = param.Opt[float64]{}
185 warnings = append(warnings, fantasy.CallWarning{
186 Type: fantasy.CallWarningTypeUnsupportedSetting,
187 Setting: "FrequencyPenalty",
188 Details: "FrequencyPenalty is not supported for reasoning models",
189 })
190 }
191 if call.PresencePenalty != nil {
192 params.PresencePenalty = param.Opt[float64]{}
193 warnings = append(warnings, fantasy.CallWarning{
194 Type: fantasy.CallWarningTypeUnsupportedSetting,
195 Setting: "PresencePenalty",
196 Details: "PresencePenalty is not supported for reasoning models",
197 })
198 }
199
200 // reasoning models use max_completion_tokens instead of max_tokens
201 if call.MaxOutputTokens != nil {
202 if !params.MaxCompletionTokens.Valid() {
203 params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
204 }
205 params.MaxTokens = param.Opt[int64]{}
206 }
207 }
208
209 // Handle search preview models
210 if isSearchPreviewModel(o.modelID) {
211 if call.Temperature != nil {
212 params.Temperature = param.Opt[float64]{}
213 warnings = append(warnings, fantasy.CallWarning{
214 Type: fantasy.CallWarningTypeUnsupportedSetting,
215 Setting: "temperature",
216 Details: "temperature is not supported for the search preview models and has been removed.",
217 })
218 }
219 }
220
221 optionsWarnings, err := o.prepareCallFunc(o, params, call)
222 if err != nil {
223 return nil, nil, err
224 }
225
226 if len(optionsWarnings) > 0 {
227 warnings = append(warnings, optionsWarnings...)
228 }
229
230 params.Messages = messages
231 params.Model = o.modelID
232
233 if len(call.Tools) > 0 {
234 tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
235 params.Tools = tools
236 if toolChoice != nil {
237 params.ToolChoice = *toolChoice
238 }
239 warnings = append(warnings, toolWarnings...)
240 }
241 return params, warnings, nil
242}
243
// Generate implements fantasy.LanguageModel. It performs a single
// (non-streaming) chat completion, converting the first choice into fantasy
// content parts: text, provider-specific extra content, tool calls, and
// URL-citation sources.
func (o languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params, callUARequestOptions(call)...)
	if err != nil {
		return nil, toProviderErr(err)
	}
	if response == nil {
		return nil, &fantasy.Error{Title: "no response", Message: "provider returned nil response"}
	}

	if len(response.Choices) == 0 {
		return nil, &fantasy.Error{Title: "no response", Message: "no response generated"}
	}
	// Only the first choice is mapped; additional choices are ignored.
	choice := response.Choices[0]
	content := make([]fantasy.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
	text := choice.Message.Content
	if text != "" {
		content = append(content, fantasy.TextContent{
			Text: text,
		})
	}
	// Allow provider-specific hooks to append extra content (if configured).
	if o.extraContentFunc != nil {
		extraContent := o.extraContentFunc(choice)
		content = append(content, extraContent...)
	}
	for _, tc := range choice.Message.ToolCalls {
		toolCallID := tc.ID
		content = append(content, fantasy.ToolCallContent{
			ProviderExecuted: false,
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Map URL-citation annotations to source content parts.
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, fantasy.SourceContent{
				SourceType: fantasy.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	usage, providerMetadata := o.usageFunc(*response)

	// The presence of tool calls overrides the provider's finish reason.
	mappedFinishReason := o.mapFinishReasonFunc(choice.FinishReason)
	if len(choice.Message.ToolCalls) > 0 {
		mappedFinishReason = fantasy.FinishReasonToolCalls
	}
	return &fantasy.Response{
		Content:      content,
		Usage:        usage,
		FinishReason: mappedFinishReason,
		ProviderMetadata: fantasy.ProviderMetadata{
			Name: providerMetadata,
		},
		Warnings: warnings,
	}, nil
}
309
// Stream implements fantasy.LanguageModel. It opens a streaming chat
// completion and returns an iterator that yields stream parts (warnings, text
// deltas, tool-call fragments, sources, and a final finish or error part) as
// chunks arrive from the provider.
func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}

	// Ask the provider to include token usage in the final stream chunk.
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params, callUARequestOptions(call)...)
	isActiveText := false                        // whether a text block (ID "0") is currently open
	toolCalls := make(map[int64]streamToolCall)  // partial tool calls keyed by delta index

	providerMetadata := fantasy.ProviderMetadata{
		Name: &ProviderMetadata{},
	}
	acc := openai.ChatCompletionAccumulator{} // accumulates chunks for final-message inspection
	extraContext := make(map[string]any)      // scratch state shared with the stream hooks
	var usage fantasy.Usage
	var finishReason string
	return func(yield func(fantasy.StreamPart) bool) {
		// Surface preparation warnings before any content.
		if len(warnings) > 0 {
			if !yield(fantasy.StreamPart{
				Type:     fantasy.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}
		for stream.Next() {
			chunk := stream.Current()
			acc.AddChunk(chunk)
			usage, providerMetadata = o.streamUsageFunc(chunk, extraContext, providerMetadata)
			if len(chunk.Choices) == 0 {
				continue
			}
			for _, choice := range chunk.Choices {
				if choice.FinishReason != "" {
					finishReason = choice.FinishReason
				}
				switch {
				case choice.Delta.Content != "":
					// Lazily open a text block on the first content delta.
					if !isActiveText {
						isActiveText = true
						if !yield(fantasy.StreamPart{
							Type: fantasy.StreamPartTypeTextStart,
							ID:   "0",
						}) {
							return
						}
					}
					if !yield(fantasy.StreamPart{
						Type:  fantasy.StreamPartTypeTextDelta,
						ID:    "0",
						Delta: choice.Delta.Content,
					}) {
						return
					}
				case len(choice.Delta.ToolCalls) > 0:
					// A tool-call delta implicitly closes any open text block.
					if isActiveText {
						isActiveText = false
						if !yield(fantasy.StreamPart{
							Type: fantasy.StreamPartTypeTextEnd,
							ID:   "0",
						}) {
							return
						}
					}

					for _, toolCallDelta := range choice.Delta.ToolCalls {
						if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
							// Continuation of a tool call we have already started.
							if existingToolCall.hasFinished {
								continue
							}
							if toolCallDelta.Function.Arguments != "" {
								existingToolCall.arguments += toolCallDelta.Function.Arguments
							}
							if !yield(fantasy.StreamPart{
								Type:  fantasy.StreamPartTypeToolInputDelta,
								ID:    existingToolCall.id,
								Delta: toolCallDelta.Function.Arguments,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = existingToolCall
							// Once the accumulated arguments form valid JSON the
							// tool call is complete and can be emitted.
							if xjson.IsValid(existingToolCall.arguments) {
								if !yield(fantasy.StreamPart{
									Type: fantasy.StreamPartTypeToolInputEnd,
									ID:   existingToolCall.id,
								}) {
									return
								}

								if !yield(fantasy.StreamPart{
									Type:          fantasy.StreamPartTypeToolCall,
									ID:            existingToolCall.id,
									ToolCallName:  existingToolCall.name,
									ToolCallInput: existingToolCall.arguments,
								}) {
									return
								}
								existingToolCall.hasFinished = true
								toolCalls[toolCallDelta.Index] = existingToolCall
							}
						} else {
							// Some provider like Ollama may send empty tool calls or miss some fields.
							// We'll skip when we don't have enough info and also assume sane defaults.
							if toolCallDelta.Function.Name == "" && toolCallDelta.Function.Arguments == "" {
								continue
							}
							toolCallDelta.Type = cmp.Or(toolCallDelta.Type, "function")
							toolCallDelta.ID = cmp.Or(toolCallDelta.ID, fmt.Sprintf("tool-call-%d", toolCallDelta.Index))

							if toolCallDelta.Type != "function" {
								yield(fantasy.StreamPart{
									Type:  fantasy.StreamPartTypeError,
									Error: &fantasy.Error{Title: "invalid provider response", Message: "expected 'function' type."},
								})
								return
							}

							if !yield(fantasy.StreamPart{
								Type:         fantasy.StreamPartTypeToolInputStart,
								ID:           toolCallDelta.ID,
								ToolCallName: toolCallDelta.Function.Name,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = streamToolCall{
								id:        toolCallDelta.ID,
								name:      toolCallDelta.Function.Name,
								arguments: toolCallDelta.Function.Arguments,
							}

							// The first delta may already carry (possibly complete)
							// arguments; emit them immediately.
							exTc := toolCalls[toolCallDelta.Index]
							if exTc.arguments != "" {
								if !yield(fantasy.StreamPart{
									Type:  fantasy.StreamPartTypeToolInputDelta,
									ID:    exTc.id,
									Delta: exTc.arguments,
								}) {
									return
								}
								if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
									if !yield(fantasy.StreamPart{
										Type: fantasy.StreamPartTypeToolInputEnd,
										ID:   exTc.id,
									}) {
										return
									}

									if !yield(fantasy.StreamPart{
										Type:          fantasy.StreamPartTypeToolCall,
										ID:            exTc.id,
										ToolCallName:  exTc.name,
										ToolCallInput: exTc.arguments,
									}) {
										return
									}
									exTc.hasFinished = true
									toolCalls[toolCallDelta.Index] = exTc
								}
							}
							continue
						}
					}
				}

				// NOTE(review): streamExtraFunc receives the whole chunk but is
				// invoked once per choice; for multi-choice chunks it therefore
				// runs multiple times per chunk — confirm this is intended.
				if o.streamExtraFunc != nil {
					updatedContext, shouldContinue := o.streamExtraFunc(chunk, yield, extraContext)
					if !shouldContinue {
						return
					}
					extraContext = updatedContext
				}
			}

			// Emit URL-citation sources found in the raw delta annotations.
			for _, choice := range chunk.Choices {
				if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
					for _, annotation := range annotations {
						if annotation.Type == "url_citation" {
							if !yield(fantasy.StreamPart{
								Type:       fantasy.StreamPartTypeSource,
								ID:         uuid.NewString(),
								SourceType: fantasy.SourceTypeURL,
								URL:        annotation.URLCitation.URL,
								Title:      annotation.URLCitation.Title,
							}) {
								return
							}
						}
					}
				}
			}
		}
		err := stream.Err()
		if err == nil || errors.Is(err, io.EOF) {
			// Close any text block that is still open.
			if isActiveText {
				isActiveText = false
				if !yield(fantasy.StreamPart{
					Type: fantasy.StreamPartTypeTextEnd,
					ID:   "0",
				}) {
					return
				}
			}

			// Handle tool calls that finish with empty arguments (e.g., Copilot).
			// Normalize empty args to "{}" and emit the tool call.
			// If the arguments are invalid JSON, we still yield the tool call
			// so the consumer (agent) can handle the error rather than
			// silently dropping it.
			for idx, tc := range toolCalls {
				if tc.hasFinished {
					continue
				}
				if tc.arguments == "" {
					tc.arguments = "{}"
					toolCalls[idx] = tc
				}
				if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeToolInputEnd, ID: tc.id}) {
					return
				}
				if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeToolCall, ID: tc.id, ToolCallName: tc.name, ToolCallInput: tc.arguments}) {
					return
				}
				tc.hasFinished = true
				toolCalls[idx] = tc
			}

			// Extract final provider metadata and sources from the accumulated
			// first choice.
			if len(acc.Choices) > 0 {
				choice := acc.Choices[0]
				providerMetadata = o.streamProviderMetadataFunc(choice, providerMetadata)

				for _, annotation := range choice.Message.Annotations {
					if annotation.Type == "url_citation" {
						if !yield(fantasy.StreamPart{
							Type:       fantasy.StreamPartTypeSource,
							ID:         acc.ID,
							SourceType: fantasy.SourceTypeURL,
							URL:        annotation.URLCitation.URL,
							Title:      annotation.URLCitation.Title,
						}) {
							return
						}
					}
				}
			}
			mappedFinishReason := o.mapFinishReasonFunc(finishReason)
			// Any tool call in the final message forces a tool-calls finish reason.
			if len(acc.Choices) > 0 {
				choice := acc.Choices[0]
				if len(choice.Message.ToolCalls) > 0 {
					mappedFinishReason = fantasy.FinishReasonToolCalls
				}
			}
			yield(fantasy.StreamPart{
				Type:             fantasy.StreamPartTypeFinish,
				Usage:            usage,
				FinishReason:     mappedFinishReason,
				ProviderMetadata: providerMetadata,
			})
			return
		} else { //nolint: revive
			yield(fantasy.StreamPart{
				Type:  fantasy.StreamPartTypeError,
				Error: toProviderErr(err),
			})
			return
		}
	}, nil
}
583
// isReasoningModel reports whether the model is a reasoning model: the
// o-series (o1/o3/o4), open-weight "oss" models, or anything in the gpt-5
// family. Reasoning models accept a restricted set of sampling parameters.
//
// Note: every model containing "gpt-5" matches, including "gpt-5-chat"
// variants — the previous explicit "gpt-5-chat" check was dead code because
// it was already subsumed by the "gpt-5" substring check.
func isReasoningModel(modelID string) bool {
	for _, marker := range []string{"o1", "o3", "o4", "oss"} {
		if strings.HasPrefix(modelID, marker) || strings.Contains(modelID, "-"+marker) {
			return true
		}
	}
	return strings.Contains(modelID, "gpt-5")
}
591
// isSearchPreviewModel reports whether the model is one of the web-search
// preview variants, which reject the temperature setting.
func isSearchPreviewModel(modelID string) bool {
	const marker = "search-preview"
	return strings.Contains(modelID, marker)
}
595
// supportsFlexProcessing reports whether the model supports the flex service
// tier (o3 family, o4-mini, and the gpt-5 family).
func supportsFlexProcessing(modelID string) bool {
	switch {
	case strings.HasPrefix(modelID, "o3"),
		strings.Contains(modelID, "-o3"),
		strings.Contains(modelID, "o4-mini"),
		strings.Contains(modelID, "gpt-5"):
		return true
	}
	return false
}
600
// supportsPriorityProcessing reports whether the model supports the priority
// service tier: the gpt-4 and gpt-5 families, o3, and o4-mini.
//
// The previous explicit "gpt-5-mini" check was dead code — it was already
// subsumed by the "gpt-5" substring check — and has been removed.
func supportsPriorityProcessing(modelID string) bool {
	return strings.Contains(modelID, "gpt-4") ||
		strings.Contains(modelID, "gpt-5") ||
		strings.HasPrefix(modelID, "o3") ||
		strings.Contains(modelID, "-o3") ||
		strings.Contains(modelID, "o4-mini")
}
606
607func toOpenAiTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []fantasy.CallWarning) {
608 for _, tool := range tools {
609 if tool.GetType() == fantasy.ToolTypeFunction {
610 ft, ok := tool.(fantasy.FunctionTool)
611 if !ok {
612 continue
613 }
614 openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
615 OfFunction: &openai.ChatCompletionFunctionToolParam{
616 Function: shared.FunctionDefinitionParam{
617 Name: ft.Name,
618 Description: param.NewOpt(ft.Description),
619 Parameters: openai.FunctionParameters(ft.InputSchema),
620 Strict: param.NewOpt(false),
621 },
622 Type: "function",
623 },
624 })
625 continue
626 }
627
628 warnings = append(warnings, fantasy.CallWarning{
629 Type: fantasy.CallWarningTypeUnsupportedTool,
630 Tool: tool,
631 Message: "tool is not supported",
632 })
633 }
634 if toolChoice == nil {
635 return openAiTools, openAiToolChoice, warnings
636 }
637
638 switch *toolChoice {
639 case fantasy.ToolChoiceAuto:
640 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
641 OfAuto: param.NewOpt("auto"),
642 }
643 case fantasy.ToolChoiceNone:
644 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
645 OfAuto: param.NewOpt("none"),
646 }
647 case fantasy.ToolChoiceRequired:
648 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
649 OfAuto: param.NewOpt("required"),
650 }
651 default:
652 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
653 OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
654 Type: "function",
655 Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
656 Name: string(*toolChoice),
657 },
658 },
659 }
660 }
661 return openAiTools, openAiToolChoice, warnings
662}
663
664// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
665func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
666 var annotations []openai.ChatCompletionMessageAnnotation
667
668 // Parse the raw JSON to extract annotations
669 var deltaData map[string]any
670 if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
671 return annotations
672 }
673
674 // Check if annotations exist in the delta
675 if annotationsData, ok := deltaData["annotations"].([]any); ok {
676 for _, annotationData := range annotationsData {
677 if annotationMap, ok := annotationData.(map[string]any); ok {
678 if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
679 if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
680 url, urlOk := urlCitationData["url"].(string)
681 title, titleOk := urlCitationData["title"].(string)
682 if urlOk && titleOk {
683 annotation := openai.ChatCompletionMessageAnnotation{
684 Type: "url_citation",
685 URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
686 URL: url,
687 Title: title,
688 },
689 }
690 annotations = append(annotations, annotation)
691 }
692 }
693 }
694 }
695 }
696 }
697
698 return annotations
699}
700
701// GenerateObject implements fantasy.LanguageModel.
702func (o languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
703 switch o.objectMode {
704 case fantasy.ObjectModeText:
705 return object.GenerateWithText(ctx, o, call)
706 case fantasy.ObjectModeTool:
707 return object.GenerateWithTool(ctx, o, call)
708 default:
709 return o.generateObjectWithJSONMode(ctx, call)
710 }
711}
712
713// StreamObject implements fantasy.LanguageModel.
714func (o languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
715 switch o.objectMode {
716 case fantasy.ObjectModeTool:
717 return object.StreamWithTool(ctx, o, call)
718 case fantasy.ObjectModeText:
719 return object.StreamWithText(ctx, o, call)
720 default:
721 return o.streamObjectWithJSONMode(ctx, call)
722 }
723}
724
725func (o languageModel) generateObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
726 jsonSchemaMap := schema.ToMap(call.Schema)
727
728 addAdditionalPropertiesFalse(jsonSchemaMap)
729
730 schemaName := call.SchemaName
731 if schemaName == "" {
732 schemaName = "response"
733 }
734
735 fantasyCall := fantasy.Call{
736 Prompt: call.Prompt,
737 MaxOutputTokens: call.MaxOutputTokens,
738 Temperature: call.Temperature,
739 TopP: call.TopP,
740 PresencePenalty: call.PresencePenalty,
741 FrequencyPenalty: call.FrequencyPenalty,
742 ProviderOptions: call.ProviderOptions,
743 }
744
745 params, warnings, err := o.prepareParams(fantasyCall)
746 if err != nil {
747 return nil, err
748 }
749
750 params.ResponseFormat = openai.ChatCompletionNewParamsResponseFormatUnion{
751 OfJSONSchema: &shared.ResponseFormatJSONSchemaParam{
752 JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
753 Name: schemaName,
754 Description: param.NewOpt(call.SchemaDescription),
755 Schema: jsonSchemaMap,
756 Strict: param.NewOpt(true),
757 },
758 },
759 }
760
761 response, err := o.client.Chat.Completions.New(ctx, *params, objectCallUARequestOptions(call)...)
762 if err != nil {
763 return nil, toProviderErr(err)
764 }
765 if len(response.Choices) == 0 {
766 usage, _ := o.usageFunc(*response)
767 return nil, &fantasy.NoObjectGeneratedError{
768 RawText: "",
769 ParseError: fmt.Errorf("no choices in response"),
770 Usage: usage,
771 FinishReason: fantasy.FinishReasonUnknown,
772 }
773 }
774
775 choice := response.Choices[0]
776 jsonText := choice.Message.Content
777
778 var obj any
779 if call.RepairText != nil {
780 obj, err = schema.ParseAndValidateWithRepair(ctx, jsonText, call.Schema, call.RepairText)
781 } else {
782 obj, err = schema.ParseAndValidate(jsonText, call.Schema)
783 }
784
785 usage, _ := o.usageFunc(*response)
786 finishReason := o.mapFinishReasonFunc(choice.FinishReason)
787
788 if err != nil {
789 if nogErr, ok := err.(*fantasy.NoObjectGeneratedError); ok {
790 nogErr.Usage = usage
791 nogErr.FinishReason = finishReason
792 }
793 return nil, err
794 }
795
796 return &fantasy.ObjectResponse{
797 Object: obj,
798 RawText: jsonText,
799 Usage: usage,
800 FinishReason: finishReason,
801 Warnings: warnings,
802 }, nil
803}
804
// streamObjectWithJSONMode streams a structured object using OpenAI's strict
// JSON-schema response format. As content deltas arrive, the accumulated text
// is incrementally parsed; each new valid object that differs from the last
// emitted one is yielded, followed by a finish (or error) part at the end.
func (o languageModel) streamObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
	jsonSchemaMap := schema.ToMap(call.Schema)

	// Strict mode requires additionalProperties: false on all object schemas.
	addAdditionalPropertiesFalse(jsonSchemaMap)

	schemaName := call.SchemaName
	if schemaName == "" {
		schemaName = "response"
	}

	// Reuse the regular call-preparation path for prompt and sampling params.
	fantasyCall := fantasy.Call{
		Prompt:           call.Prompt,
		MaxOutputTokens:  call.MaxOutputTokens,
		Temperature:      call.Temperature,
		TopP:             call.TopP,
		PresencePenalty:  call.PresencePenalty,
		FrequencyPenalty: call.FrequencyPenalty,
		ProviderOptions:  call.ProviderOptions,
	}

	params, warnings, err := o.prepareParams(fantasyCall)
	if err != nil {
		return nil, err
	}

	params.ResponseFormat = openai.ChatCompletionNewParamsResponseFormatUnion{
		OfJSONSchema: &shared.ResponseFormatJSONSchemaParam{
			JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
				Name:        schemaName,
				Description: param.NewOpt(call.SchemaDescription),
				Schema:      jsonSchemaMap,
				Strict:      param.NewOpt(true),
			},
		},
	}

	// Ask the provider to include token usage in the final stream chunk.
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params, objectCallUARequestOptions(call)...)

	return func(yield func(fantasy.ObjectStreamPart) bool) {
		// NOTE(review): warnings are emitted with the Object part type here,
		// while Stream uses a dedicated warnings part type — confirm that
		// ObjectStreamPart has no dedicated warnings type.
		if len(warnings) > 0 {
			if !yield(fantasy.ObjectStreamPart{
				Type:     fantasy.ObjectStreamPartTypeObject,
				Warnings: warnings,
			}) {
				return
			}
		}

		var accumulated string   // full JSON text received so far
		var lastParsedObject any // last object successfully parsed and validated
		var usage fantasy.Usage
		var finishReason fantasy.FinishReason
		var providerMetadata fantasy.ProviderMetadata
		var streamErr error

		for stream.Next() {
			chunk := stream.Current()

			// Update usage
			usage, providerMetadata = o.streamUsageFunc(chunk, make(map[string]any), providerMetadata)

			if len(chunk.Choices) == 0 {
				continue
			}

			// Only the first choice is inspected.
			choice := chunk.Choices[0]
			if choice.FinishReason != "" {
				finishReason = o.mapFinishReasonFunc(choice.FinishReason)
			}

			if choice.Delta.Content != "" {
				accumulated += choice.Delta.Content

				obj, state, parseErr := schema.ParsePartialJSON(accumulated)

				// Yield a new object only when it parses, validates, and
				// differs from the last one emitted.
				if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
					if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
						if !reflect.DeepEqual(obj, lastParsedObject) {
							if !yield(fantasy.ObjectStreamPart{
								Type:   fantasy.ObjectStreamPartTypeObject,
								Object: obj,
							}) {
								return
							}
							lastParsedObject = obj
						}
					}
				}

				// On a hard parse failure, give the caller's repair function a
				// chance to fix the accumulated text.
				if state == schema.ParseStateFailed && call.RepairText != nil {
					repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
					if repairErr == nil {
						obj2, state2, _ := schema.ParsePartialJSON(repairedText)
						if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
							schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
							if !reflect.DeepEqual(obj2, lastParsedObject) {
								if !yield(fantasy.ObjectStreamPart{
									Type:   fantasy.ObjectStreamPartTypeObject,
									Object: obj2,
								}) {
									return
								}
								lastParsedObject = obj2
							}
						}
					}
				}
			}
		}

		err := stream.Err()
		if err != nil && !errors.Is(err, io.EOF) {
			streamErr = toProviderErr(err)
			yield(fantasy.ObjectStreamPart{
				Type:  fantasy.ObjectStreamPartTypeError,
				Error: streamErr,
			})
			return
		}

		// Finish if at least one valid object was produced; otherwise report
		// a no-object-generated error with the raw accumulated text.
		if lastParsedObject != nil {
			yield(fantasy.ObjectStreamPart{
				Type:             fantasy.ObjectStreamPartTypeFinish,
				Usage:            usage,
				FinishReason:     finishReason,
				ProviderMetadata: providerMetadata,
			})
		} else {
			yield(fantasy.ObjectStreamPart{
				Type: fantasy.ObjectStreamPartTypeError,
				Error: &fantasy.NoObjectGeneratedError{
					RawText:      accumulated,
					ParseError:   fmt.Errorf("no valid object generated in stream"),
					Usage:        usage,
					FinishReason: finishReason,
				},
			})
		}
	}, nil
}
949
// addAdditionalPropertiesFalse recursively adds "additionalProperties": false
// to every object schema in the map, as required by OpenAI's strict mode for
// structured outputs. Existing "additionalProperties" values are preserved,
// and array "items" schemas are descended into.
//
// The parameter was renamed from "schema" to avoid shadowing the imported
// schema package. Note that combinators such as "anyOf"/"oneOf" and "$defs"
// are not traversed.
func addAdditionalPropertiesFalse(schemaMap map[string]any) {
	if schemaMap == nil {
		return
	}
	if schemaMap["type"] == "object" {
		if _, hasAdditional := schemaMap["additionalProperties"]; !hasAdditional {
			schemaMap["additionalProperties"] = false
		}

		// Recursively process nested properties.
		if properties, ok := schemaMap["properties"].(map[string]any); ok {
			for _, propValue := range properties {
				if propSchema, ok := propValue.(map[string]any); ok {
					addAdditionalPropertiesFalse(propSchema)
				}
			}
		}
	}

	// Handle array items.
	if items, ok := schemaMap["items"].(map[string]any); ok {
		addAdditionalPropertiesFalse(items)
	}
}