1package openai
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "reflect"
11 "strings"
12
13 "charm.land/fantasy"
14 "charm.land/fantasy/object"
15 "charm.land/fantasy/schema"
16 "github.com/charmbracelet/openai-go"
17 "github.com/charmbracelet/openai-go/packages/param"
18 "github.com/charmbracelet/openai-go/shared"
19 xjson "github.com/charmbracelet/x/json"
20 "github.com/google/uuid"
21)
22
// languageModel is the chat-completions implementation of
// fantasy.LanguageModel for OpenAI and OpenAI-compatible providers.
// The *Func fields are customization hooks so derived providers can
// adjust request preparation, usage accounting, and response mapping;
// newLanguageModel installs defaults for them.
type languageModel struct {
	provider   string             // provider name reported by Provider()
	modelID    string             // model identifier reported by Model()
	client     openai.Client      // underlying OpenAI API client
	objectMode fantasy.ObjectMode // strategy used by GenerateObject/StreamObject

	// Hooks (all have Default* implementations, see newLanguageModel).
	prepareCallFunc            LanguageModelPrepareCallFunc
	mapFinishReasonFunc        LanguageModelMapFinishReasonFunc
	extraContentFunc           LanguageModelExtraContentFunc
	usageFunc                  LanguageModelUsageFunc
	streamUsageFunc            LanguageModelStreamUsageFunc
	streamExtraFunc            LanguageModelStreamExtraFunc
	streamProviderMetadataFunc LanguageModelStreamProviderMetadataFunc
	toPromptFunc               LanguageModelToPromptFunc
}
37
// LanguageModelOption is a function that configures a languageModel.
// Options are applied in order by newLanguageModel, so later options
// override earlier ones.
type LanguageModelOption = func(*languageModel)
40
41// WithLanguageModelPrepareCallFunc sets the prepare call function for the language model.
42func WithLanguageModelPrepareCallFunc(fn LanguageModelPrepareCallFunc) LanguageModelOption {
43 return func(l *languageModel) {
44 l.prepareCallFunc = fn
45 }
46}
47
48// WithLanguageModelMapFinishReasonFunc sets the map finish reason function for the language model.
49func WithLanguageModelMapFinishReasonFunc(fn LanguageModelMapFinishReasonFunc) LanguageModelOption {
50 return func(l *languageModel) {
51 l.mapFinishReasonFunc = fn
52 }
53}
54
55// WithLanguageModelExtraContentFunc sets the extra content function for the language model.
56func WithLanguageModelExtraContentFunc(fn LanguageModelExtraContentFunc) LanguageModelOption {
57 return func(l *languageModel) {
58 l.extraContentFunc = fn
59 }
60}
61
62// WithLanguageModelStreamExtraFunc sets the stream extra function for the language model.
63func WithLanguageModelStreamExtraFunc(fn LanguageModelStreamExtraFunc) LanguageModelOption {
64 return func(l *languageModel) {
65 l.streamExtraFunc = fn
66 }
67}
68
69// WithLanguageModelUsageFunc sets the usage function for the language model.
70func WithLanguageModelUsageFunc(fn LanguageModelUsageFunc) LanguageModelOption {
71 return func(l *languageModel) {
72 l.usageFunc = fn
73 }
74}
75
76// WithLanguageModelStreamUsageFunc sets the stream usage function for the language model.
77func WithLanguageModelStreamUsageFunc(fn LanguageModelStreamUsageFunc) LanguageModelOption {
78 return func(l *languageModel) {
79 l.streamUsageFunc = fn
80 }
81}
82
83// WithLanguageModelToPromptFunc sets the to prompt function for the language model.
84func WithLanguageModelToPromptFunc(fn LanguageModelToPromptFunc) LanguageModelOption {
85 return func(l *languageModel) {
86 l.toPromptFunc = fn
87 }
88}
89
90// WithLanguageModelObjectMode sets the object generation mode.
91func WithLanguageModelObjectMode(om fantasy.ObjectMode) LanguageModelOption {
92 return func(l *languageModel) {
93 // not supported
94 if om == fantasy.ObjectModeJSON {
95 om = fantasy.ObjectModeAuto
96 }
97 l.objectMode = om
98 }
99}
100
101func newLanguageModel(modelID string, provider string, client openai.Client, opts ...LanguageModelOption) languageModel {
102 model := languageModel{
103 modelID: modelID,
104 provider: provider,
105 client: client,
106 objectMode: fantasy.ObjectModeAuto,
107 prepareCallFunc: DefaultPrepareCallFunc,
108 mapFinishReasonFunc: DefaultMapFinishReasonFunc,
109 usageFunc: DefaultUsageFunc,
110 streamUsageFunc: DefaultStreamUsageFunc,
111 streamProviderMetadataFunc: DefaultStreamProviderMetadataFunc,
112 toPromptFunc: DefaultToPrompt,
113 }
114
115 for _, o := range opts {
116 o(&model)
117 }
118 return model
119}
120
// streamToolCall tracks the incremental state of a single tool call while its
// pieces arrive as streaming deltas.
type streamToolCall struct {
	id          string // tool call ID (may be synthesized from the index)
	name        string // function name
	arguments   string // concatenated JSON argument fragments so far
	hasFinished bool   // true once the complete tool call has been emitted
}
127
// Model implements fantasy.LanguageModel. It returns the model identifier
// this instance was created with.
func (o languageModel) Model() string {
	return o.modelID
}
132
// Provider implements fantasy.LanguageModel. It returns the provider name
// this instance was created with.
func (o languageModel) Provider() string {
	return o.provider
}
137
138func (o languageModel) prepareParams(call fantasy.Call) (*openai.ChatCompletionNewParams, []fantasy.CallWarning, error) {
139 params := &openai.ChatCompletionNewParams{}
140 messages, warnings := o.toPromptFunc(call.Prompt, o.provider, o.modelID)
141 if call.TopK != nil {
142 warnings = append(warnings, fantasy.CallWarning{
143 Type: fantasy.CallWarningTypeUnsupportedSetting,
144 Setting: "top_k",
145 })
146 }
147
148 if call.MaxOutputTokens != nil {
149 params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
150 }
151 if call.Temperature != nil {
152 params.Temperature = param.NewOpt(*call.Temperature)
153 }
154 if call.TopP != nil {
155 params.TopP = param.NewOpt(*call.TopP)
156 }
157 if call.FrequencyPenalty != nil {
158 params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
159 }
160 if call.PresencePenalty != nil {
161 params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
162 }
163
164 if isReasoningModel(o.modelID) {
165 // remove unsupported settings for reasoning models
166 // see https://platform.openai.com/docs/guides/reasoning#limitations
167 if call.Temperature != nil {
168 params.Temperature = param.Opt[float64]{}
169 warnings = append(warnings, fantasy.CallWarning{
170 Type: fantasy.CallWarningTypeUnsupportedSetting,
171 Setting: "temperature",
172 Details: "temperature is not supported for reasoning models",
173 })
174 }
175 if call.TopP != nil {
176 params.TopP = param.Opt[float64]{}
177 warnings = append(warnings, fantasy.CallWarning{
178 Type: fantasy.CallWarningTypeUnsupportedSetting,
179 Setting: "TopP",
180 Details: "TopP is not supported for reasoning models",
181 })
182 }
183 if call.FrequencyPenalty != nil {
184 params.FrequencyPenalty = param.Opt[float64]{}
185 warnings = append(warnings, fantasy.CallWarning{
186 Type: fantasy.CallWarningTypeUnsupportedSetting,
187 Setting: "FrequencyPenalty",
188 Details: "FrequencyPenalty is not supported for reasoning models",
189 })
190 }
191 if call.PresencePenalty != nil {
192 params.PresencePenalty = param.Opt[float64]{}
193 warnings = append(warnings, fantasy.CallWarning{
194 Type: fantasy.CallWarningTypeUnsupportedSetting,
195 Setting: "PresencePenalty",
196 Details: "PresencePenalty is not supported for reasoning models",
197 })
198 }
199
200 // reasoning models use max_completion_tokens instead of max_tokens
201 if call.MaxOutputTokens != nil {
202 if !params.MaxCompletionTokens.Valid() {
203 params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
204 }
205 params.MaxTokens = param.Opt[int64]{}
206 }
207 }
208
209 // Handle search preview models
210 if isSearchPreviewModel(o.modelID) {
211 if call.Temperature != nil {
212 params.Temperature = param.Opt[float64]{}
213 warnings = append(warnings, fantasy.CallWarning{
214 Type: fantasy.CallWarningTypeUnsupportedSetting,
215 Setting: "temperature",
216 Details: "temperature is not supported for the search preview models and has been removed.",
217 })
218 }
219 }
220
221 optionsWarnings, err := o.prepareCallFunc(o, params, call)
222 if err != nil {
223 return nil, nil, err
224 }
225
226 if len(optionsWarnings) > 0 {
227 warnings = append(warnings, optionsWarnings...)
228 }
229
230 params.Messages = messages
231 params.Model = o.modelID
232
233 if len(call.Tools) > 0 {
234 tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
235 params.Tools = tools
236 if toolChoice != nil {
237 params.ToolChoice = *toolChoice
238 }
239 warnings = append(warnings, toolWarnings...)
240 }
241 return params, warnings, nil
242}
243
// Generate implements fantasy.LanguageModel.
//
// It performs a single (non-streaming) chat completion and converts the first
// choice into fantasy content parts: text, provider-specific extra content,
// tool calls, and URL-citation sources, plus usage and finish-reason info.
func (o languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params, callUARequestOptions(call)...)
	if err != nil {
		return nil, toProviderErr(err)
	}

	if len(response.Choices) == 0 {
		return nil, &fantasy.Error{Title: "no response", Message: "no response generated"}
	}
	choice := response.Choices[0]
	// Capacity: one optional text part plus one part per tool call/annotation.
	content := make([]fantasy.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
	text := choice.Message.Content
	if text != "" {
		content = append(content, fantasy.TextContent{
			Text: text,
		})
	}
	// Optional hook for additional provider-specific content parts.
	if o.extraContentFunc != nil {
		extraContent := o.extraContentFunc(choice)
		content = append(content, extraContent...)
	}
	for _, tc := range choice.Message.ToolCalls {
		toolCallID := tc.ID
		content = append(content, fantasy.ToolCallContent{
			ProviderExecuted: false,
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Surface URL citations as source content parts.
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, fantasy.SourceContent{
				SourceType: fantasy.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	usage, providerMetadata := o.usageFunc(*response)

	// Normalize the finish reason to tool-calls whenever tool calls are
	// present, regardless of what the provider reported.
	mappedFinishReason := o.mapFinishReasonFunc(choice.FinishReason)
	if len(choice.Message.ToolCalls) > 0 {
		mappedFinishReason = fantasy.FinishReasonToolCalls
	}
	return &fantasy.Response{
		Content:      content,
		Usage:        usage,
		FinishReason: mappedFinishReason,
		ProviderMetadata: fantasy.ProviderMetadata{
			Name: providerMetadata,
		},
		Warnings: warnings,
	}, nil
}
306
// Stream implements fantasy.LanguageModel.
//
// It opens a streaming chat completion and returns an iterator that yields
// fantasy.StreamPart values in order: warnings (if any), then text and
// tool-call deltas as chunks arrive, then URL-citation sources, and finally a
// finish part (or an error part if the stream failed). Text content is
// streamed under the fixed part ID "0"; tool calls are assembled
// incrementally and emitted as soon as their accumulated arguments form valid
// JSON.
func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}

	// Ask the API to include token usage in the final stream chunk.
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params, callUARequestOptions(call)...)
	isActiveText := false                       // whether a text part (ID "0") is currently open
	toolCalls := make(map[int64]streamToolCall) // in-progress tool calls keyed by delta index

	providerMetadata := fantasy.ProviderMetadata{
		Name: &ProviderMetadata{},
	}
	// acc accumulates chunks so the complete message (tool calls,
	// annotations) can be inspected once the stream ends.
	acc := openai.ChatCompletionAccumulator{}
	extraContext := make(map[string]any) // scratch state shared with the stream hooks
	var usage fantasy.Usage
	var finishReason string
	return func(yield func(fantasy.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(fantasy.StreamPart{
				Type:     fantasy.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}
		for stream.Next() {
			chunk := stream.Current()
			acc.AddChunk(chunk)
			usage, providerMetadata = o.streamUsageFunc(chunk, extraContext, providerMetadata)
			if len(chunk.Choices) == 0 {
				continue
			}
			for _, choice := range chunk.Choices {
				// Remember the last non-empty finish reason for the final part.
				if choice.FinishReason != "" {
					finishReason = choice.FinishReason
				}
				switch {
				case choice.Delta.Content != "":
					// Open a text part lazily on the first content delta.
					if !isActiveText {
						isActiveText = true
						if !yield(fantasy.StreamPart{
							Type: fantasy.StreamPartTypeTextStart,
							ID:   "0",
						}) {
							return
						}
					}
					if !yield(fantasy.StreamPart{
						Type:  fantasy.StreamPartTypeTextDelta,
						ID:    "0",
						Delta: choice.Delta.Content,
					}) {
						return
					}
				case len(choice.Delta.ToolCalls) > 0:
					// Tool calls terminate any open text part.
					if isActiveText {
						isActiveText = false
						if !yield(fantasy.StreamPart{
							Type: fantasy.StreamPartTypeTextEnd,
							ID:   "0",
						}) {
							return
						}
					}

					for _, toolCallDelta := range choice.Delta.ToolCalls {
						if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
							// Continuation of a tool call we have already started.
							if existingToolCall.hasFinished {
								continue
							}
							if toolCallDelta.Function.Arguments != "" {
								existingToolCall.arguments += toolCallDelta.Function.Arguments
							}
							if !yield(fantasy.StreamPart{
								Type:  fantasy.StreamPartTypeToolInputDelta,
								ID:    existingToolCall.id,
								Delta: toolCallDelta.Function.Arguments,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = existingToolCall
							// Emit the finished tool call as soon as the
							// accumulated arguments parse as valid JSON.
							if xjson.IsValid(existingToolCall.arguments) {
								if !yield(fantasy.StreamPart{
									Type: fantasy.StreamPartTypeToolInputEnd,
									ID:   existingToolCall.id,
								}) {
									return
								}

								if !yield(fantasy.StreamPart{
									Type:          fantasy.StreamPartTypeToolCall,
									ID:            existingToolCall.id,
									ToolCallName:  existingToolCall.name,
									ToolCallInput: existingToolCall.arguments,
								}) {
									return
								}
								existingToolCall.hasFinished = true
								toolCalls[toolCallDelta.Index] = existingToolCall
							}
						} else {
							// Some provider like Ollama may send empty tool calls or miss some fields.
							// We'll skip when we don't have enough info and also assume sane defaults.
							if toolCallDelta.Function.Name == "" && toolCallDelta.Function.Arguments == "" {
								continue
							}
							toolCallDelta.Type = cmp.Or(toolCallDelta.Type, "function")
							toolCallDelta.ID = cmp.Or(toolCallDelta.ID, fmt.Sprintf("tool-call-%d", toolCallDelta.Index))

							if toolCallDelta.Type != "function" {
								yield(fantasy.StreamPart{
									Type:  fantasy.StreamPartTypeError,
									Error: &fantasy.Error{Title: "invalid provider response", Message: "expected 'function' type."},
								})
								return
							}

							if !yield(fantasy.StreamPart{
								Type:         fantasy.StreamPartTypeToolInputStart,
								ID:           toolCallDelta.ID,
								ToolCallName: toolCallDelta.Function.Name,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = streamToolCall{
								id:        toolCallDelta.ID,
								name:      toolCallDelta.Function.Name,
								arguments: toolCallDelta.Function.Arguments,
							}

							// The first delta may already carry (complete)
							// arguments; emit delta/end/call parts for it now.
							exTc := toolCalls[toolCallDelta.Index]
							if exTc.arguments != "" {
								if !yield(fantasy.StreamPart{
									Type:  fantasy.StreamPartTypeToolInputDelta,
									ID:    exTc.id,
									Delta: exTc.arguments,
								}) {
									return
								}
								if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
									if !yield(fantasy.StreamPart{
										Type: fantasy.StreamPartTypeToolInputEnd,
										ID:   exTc.id,
									}) {
										return
									}

									if !yield(fantasy.StreamPart{
										Type:          fantasy.StreamPartTypeToolCall,
										ID:            exTc.id,
										ToolCallName:  exTc.name,
										ToolCallInput: exTc.arguments,
									}) {
										return
									}
									exTc.hasFinished = true
									toolCalls[toolCallDelta.Index] = exTc
								}
							}
							continue
						}
					}
				}

				// Provider hook for extra stream parts; it may also stop the
				// stream and replaces the shared context for the next chunk.
				if o.streamExtraFunc != nil {
					updatedContext, shouldContinue := o.streamExtraFunc(chunk, yield, extraContext)
					if !shouldContinue {
						return
					}
					extraContext = updatedContext
				}
			}

			// URL citations can arrive inside deltas; surface them as sources.
			for _, choice := range chunk.Choices {
				if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
					for _, annotation := range annotations {
						if annotation.Type == "url_citation" {
							if !yield(fantasy.StreamPart{
								Type:       fantasy.StreamPartTypeSource,
								ID:         uuid.NewString(),
								SourceType: fantasy.SourceTypeURL,
								URL:        annotation.URLCitation.URL,
								Title:      annotation.URLCitation.Title,
							}) {
								return
							}
						}
					}
				}
			}
		}
		err := stream.Err()
		if err == nil || errors.Is(err, io.EOF) {
			// Stream ended normally: close any open text part first.
			if isActiveText {
				isActiveText = false
				if !yield(fantasy.StreamPart{
					Type: fantasy.StreamPartTypeTextEnd,
					ID:   "0",
				}) {
					return
				}
			}

			// Handle tool calls that finish with empty arguments (e.g., Copilot).
			// Normalize empty args to "{}" and emit the tool call if valid.
			for idx, tc := range toolCalls {
				if tc.hasFinished {
					continue
				}
				if tc.arguments == "" {
					tc.arguments = "{}"
					toolCalls[idx] = tc
				}
				if xjson.IsValid(tc.arguments) {
					if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeToolInputEnd, ID: tc.id}) {
						return
					}
					if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeToolCall, ID: tc.id, ToolCallName: tc.name, ToolCallInput: tc.arguments}) {
						return
					}
					tc.hasFinished = true
					toolCalls[idx] = tc
				}
			}

			// Emit sources for annotations seen on the accumulated message.
			if len(acc.Choices) > 0 {
				choice := acc.Choices[0]
				providerMetadata = o.streamProviderMetadataFunc(choice, providerMetadata)

				for _, annotation := range choice.Message.Annotations {
					if annotation.Type == "url_citation" {
						if !yield(fantasy.StreamPart{
							Type:       fantasy.StreamPartTypeSource,
							ID:         acc.ID,
							SourceType: fantasy.SourceTypeURL,
							URL:        annotation.URLCitation.URL,
							Title:      annotation.URLCitation.Title,
						}) {
							return
						}
					}
				}
			}
			// As in Generate, force the tool-calls finish reason when the
			// accumulated message contains tool calls.
			mappedFinishReason := o.mapFinishReasonFunc(finishReason)
			if len(acc.Choices) > 0 {
				choice := acc.Choices[0]
				if len(choice.Message.ToolCalls) > 0 {
					mappedFinishReason = fantasy.FinishReasonToolCalls
				}
			}
			yield(fantasy.StreamPart{
				Type:             fantasy.StreamPartTypeFinish,
				Usage:            usage,
				FinishReason:     mappedFinishReason,
				ProviderMetadata: providerMetadata,
			})
			return
		} else { //nolint: revive
			yield(fantasy.StreamPart{
				Type:  fantasy.StreamPartTypeError,
				Error: toProviderErr(err),
			})
			return
		}
	}, nil
}
579
// isReasoningModel reports whether modelID refers to a reasoning model
// (the o1/o3/o4/oss families and the gpt-5 family).
//
// Fix: the original ended with
// `Contains("gpt-5") || Contains("gpt-5-chat")` — the second clause was dead
// code because any ID containing "gpt-5-chat" also contains "gpt-5". The
// intent (matching the upstream AI SDK) is to EXCLUDE gpt-5-chat, which is
// the non-reasoning chat variant.
func isReasoningModel(modelID string) bool {
	if strings.Contains(modelID, "gpt-5-chat") {
		return false
	}
	return strings.HasPrefix(modelID, "o1") || strings.Contains(modelID, "-o1") ||
		strings.HasPrefix(modelID, "o3") || strings.Contains(modelID, "-o3") ||
		strings.HasPrefix(modelID, "o4") || strings.Contains(modelID, "-o4") ||
		strings.HasPrefix(modelID, "oss") || strings.Contains(modelID, "-oss") ||
		strings.Contains(modelID, "gpt-5")
}
587
// isSearchPreviewModel reports whether modelID names one of the
// *-search-preview chat models, which have restricted parameter support.
func isSearchPreviewModel(modelID string) bool {
	const marker = "search-preview"
	return strings.Contains(modelID, marker)
}
591
// supportsFlexProcessing reports whether modelID is eligible for the "flex"
// service tier (o3 family, o4-mini, and gpt-5 family models).
func supportsFlexProcessing(modelID string) bool {
	switch {
	case strings.HasPrefix(modelID, "o3"), strings.Contains(modelID, "-o3"):
		return true
	case strings.Contains(modelID, "o4-mini"), strings.Contains(modelID, "gpt-5"):
		return true
	default:
		return false
	}
}
596
// supportsPriorityProcessing reports whether modelID is eligible for the
// "priority" service tier (gpt-4 and gpt-5 families, o3 family, o4-mini).
//
// Fix: removed the dead `Contains("gpt-5-mini")` clause — it was fully
// subsumed by `Contains("gpt-5")`. Behavior is unchanged.
func supportsPriorityProcessing(modelID string) bool {
	return strings.Contains(modelID, "gpt-4") || strings.Contains(modelID, "gpt-5") ||
		strings.HasPrefix(modelID, "o3") || strings.Contains(modelID, "-o3") ||
		strings.Contains(modelID, "o4-mini")
}
602
// toOpenAiTools converts fantasy tool definitions and an optional tool choice
// into their OpenAI chat-completions equivalents. Tools that are not function
// tools are skipped and reported as warnings. A nil toolChoice yields a nil
// OpenAI tool choice (provider default).
func toOpenAiTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []fantasy.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == fantasy.ToolTypeFunction {
			ft, ok := tool.(fantasy.FunctionTool)
			if !ok {
				// Declared as a function tool but not the expected concrete
				// type; there is nothing usable to send, so skip it.
				continue
			}
			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
				OfFunction: &openai.ChatCompletionFunctionToolParam{
					Function: shared.FunctionDefinitionParam{
						Name:        ft.Name,
						Description: param.NewOpt(ft.Description),
						Parameters:  openai.FunctionParameters(ft.InputSchema),
						Strict:      param.NewOpt(false),
					},
					Type: "function",
				},
			})
			continue
		}

		warnings = append(warnings, fantasy.CallWarning{
			Type:    fantasy.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return openAiTools, openAiToolChoice, warnings
	}

	switch *toolChoice {
	case fantasy.ToolChoiceAuto:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("auto"),
		}
	case fantasy.ToolChoiceNone:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("none"),
		}
	case fantasy.ToolChoiceRequired:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("required"),
		}
	default:
		// Any other value is treated as the name of a specific tool the
		// model is forced to call.
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
				Type: "function",
				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
					Name: string(*toolChoice),
				},
			},
		}
	}
	return openAiTools, openAiToolChoice, warnings
}
659
660// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
661func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
662 var annotations []openai.ChatCompletionMessageAnnotation
663
664 // Parse the raw JSON to extract annotations
665 var deltaData map[string]any
666 if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
667 return annotations
668 }
669
670 // Check if annotations exist in the delta
671 if annotationsData, ok := deltaData["annotations"].([]any); ok {
672 for _, annotationData := range annotationsData {
673 if annotationMap, ok := annotationData.(map[string]any); ok {
674 if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
675 if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
676 url, urlOk := urlCitationData["url"].(string)
677 title, titleOk := urlCitationData["title"].(string)
678 if urlOk && titleOk {
679 annotation := openai.ChatCompletionMessageAnnotation{
680 Type: "url_citation",
681 URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
682 URL: url,
683 Title: title,
684 },
685 }
686 annotations = append(annotations, annotation)
687 }
688 }
689 }
690 }
691 }
692 }
693
694 return annotations
695}
696
697// GenerateObject implements fantasy.LanguageModel.
698func (o languageModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
699 switch o.objectMode {
700 case fantasy.ObjectModeText:
701 return object.GenerateWithText(ctx, o, call)
702 case fantasy.ObjectModeTool:
703 return object.GenerateWithTool(ctx, o, call)
704 default:
705 return o.generateObjectWithJSONMode(ctx, call)
706 }
707}
708
709// StreamObject implements fantasy.LanguageModel.
710func (o languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
711 switch o.objectMode {
712 case fantasy.ObjectModeTool:
713 return object.StreamWithTool(ctx, o, call)
714 case fantasy.ObjectModeText:
715 return object.StreamWithText(ctx, o, call)
716 default:
717 return o.streamObjectWithJSONMode(ctx, call)
718 }
719}
720
721func (o languageModel) generateObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
722 jsonSchemaMap := schema.ToMap(call.Schema)
723
724 addAdditionalPropertiesFalse(jsonSchemaMap)
725
726 schemaName := call.SchemaName
727 if schemaName == "" {
728 schemaName = "response"
729 }
730
731 fantasyCall := fantasy.Call{
732 Prompt: call.Prompt,
733 MaxOutputTokens: call.MaxOutputTokens,
734 Temperature: call.Temperature,
735 TopP: call.TopP,
736 PresencePenalty: call.PresencePenalty,
737 FrequencyPenalty: call.FrequencyPenalty,
738 ProviderOptions: call.ProviderOptions,
739 }
740
741 params, warnings, err := o.prepareParams(fantasyCall)
742 if err != nil {
743 return nil, err
744 }
745
746 params.ResponseFormat = openai.ChatCompletionNewParamsResponseFormatUnion{
747 OfJSONSchema: &shared.ResponseFormatJSONSchemaParam{
748 JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
749 Name: schemaName,
750 Description: param.NewOpt(call.SchemaDescription),
751 Schema: jsonSchemaMap,
752 Strict: param.NewOpt(true),
753 },
754 },
755 }
756
757 response, err := o.client.Chat.Completions.New(ctx, *params, objectCallUARequestOptions(call)...)
758 if err != nil {
759 return nil, toProviderErr(err)
760 }
761 if len(response.Choices) == 0 {
762 usage, _ := o.usageFunc(*response)
763 return nil, &fantasy.NoObjectGeneratedError{
764 RawText: "",
765 ParseError: fmt.Errorf("no choices in response"),
766 Usage: usage,
767 FinishReason: fantasy.FinishReasonUnknown,
768 }
769 }
770
771 choice := response.Choices[0]
772 jsonText := choice.Message.Content
773
774 var obj any
775 if call.RepairText != nil {
776 obj, err = schema.ParseAndValidateWithRepair(ctx, jsonText, call.Schema, call.RepairText)
777 } else {
778 obj, err = schema.ParseAndValidate(jsonText, call.Schema)
779 }
780
781 usage, _ := o.usageFunc(*response)
782 finishReason := o.mapFinishReasonFunc(choice.FinishReason)
783
784 if err != nil {
785 if nogErr, ok := err.(*fantasy.NoObjectGeneratedError); ok {
786 nogErr.Usage = usage
787 nogErr.FinishReason = finishReason
788 }
789 return nil, err
790 }
791
792 return &fantasy.ObjectResponse{
793 Object: obj,
794 RawText: jsonText,
795 Usage: usage,
796 FinishReason: finishReason,
797 Warnings: warnings,
798 }, nil
799}
800
// streamObjectWithJSONMode streams a structured object using OpenAI's strict
// JSON-schema response format. As content deltas arrive, the accumulated text
// is re-parsed as (partial) JSON and each new schema-valid snapshot of the
// object is yielded. A finish part is emitted only if at least one valid
// object was produced; otherwise an error part is emitted.
func (o languageModel) streamObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
	jsonSchemaMap := schema.ToMap(call.Schema)

	// OpenAI strict mode requires additionalProperties: false on objects.
	addAdditionalPropertiesFalse(jsonSchemaMap)

	schemaName := call.SchemaName
	if schemaName == "" {
		schemaName = "response"
	}

	// Re-use the regular call pipeline for prompt conversion and parameter
	// clamping.
	fantasyCall := fantasy.Call{
		Prompt:           call.Prompt,
		MaxOutputTokens:  call.MaxOutputTokens,
		Temperature:      call.Temperature,
		TopP:             call.TopP,
		PresencePenalty:  call.PresencePenalty,
		FrequencyPenalty: call.FrequencyPenalty,
		ProviderOptions:  call.ProviderOptions,
	}

	params, warnings, err := o.prepareParams(fantasyCall)
	if err != nil {
		return nil, err
	}

	params.ResponseFormat = openai.ChatCompletionNewParamsResponseFormatUnion{
		OfJSONSchema: &shared.ResponseFormatJSONSchemaParam{
			JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{
				Name:        schemaName,
				Description: param.NewOpt(call.SchemaDescription),
				Schema:      jsonSchemaMap,
				Strict:      param.NewOpt(true),
			},
		},
	}

	// Ask the API to include token usage in the final stream chunk.
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params, objectCallUARequestOptions(call)...)

	return func(yield func(fantasy.ObjectStreamPart) bool) {
		if len(warnings) > 0 {
			// NOTE(review): warnings are emitted under the Object part type;
			// confirm there is no dedicated warnings part type for object
			// streams.
			if !yield(fantasy.ObjectStreamPart{
				Type:     fantasy.ObjectStreamPartTypeObject,
				Warnings: warnings,
			}) {
				return
			}
		}

		var accumulated string      // full response text received so far
		var lastParsedObject any    // last snapshot yielded, to suppress duplicates
		var usage fantasy.Usage
		var finishReason fantasy.FinishReason
		var providerMetadata fantasy.ProviderMetadata
		var streamErr error

		for stream.Next() {
			chunk := stream.Current()

			// Update usage
			usage, providerMetadata = o.streamUsageFunc(chunk, make(map[string]any), providerMetadata)

			if len(chunk.Choices) == 0 {
				continue
			}

			choice := chunk.Choices[0]
			if choice.FinishReason != "" {
				finishReason = o.mapFinishReasonFunc(choice.FinishReason)
			}

			if choice.Delta.Content != "" {
				accumulated += choice.Delta.Content

				obj, state, parseErr := schema.ParsePartialJSON(accumulated)

				// Yield a snapshot only when it parses, validates against the
				// schema, and differs from the previous snapshot.
				if state == schema.ParseStateSuccessful || state == schema.ParseStateRepaired {
					if err := schema.ValidateAgainstSchema(obj, call.Schema); err == nil {
						if !reflect.DeepEqual(obj, lastParsedObject) {
							if !yield(fantasy.ObjectStreamPart{
								Type:   fantasy.ObjectStreamPartTypeObject,
								Object: obj,
							}) {
								return
							}
							lastParsedObject = obj
						}
					}
				}

				// On a hard parse failure, give the caller-provided repair
				// hook a chance to fix the accumulated text.
				if state == schema.ParseStateFailed && call.RepairText != nil {
					repairedText, repairErr := call.RepairText(ctx, accumulated, parseErr)
					if repairErr == nil {
						obj2, state2, _ := schema.ParsePartialJSON(repairedText)
						if (state2 == schema.ParseStateSuccessful || state2 == schema.ParseStateRepaired) &&
							schema.ValidateAgainstSchema(obj2, call.Schema) == nil {
							if !reflect.DeepEqual(obj2, lastParsedObject) {
								if !yield(fantasy.ObjectStreamPart{
									Type:   fantasy.ObjectStreamPartTypeObject,
									Object: obj2,
								}) {
									return
								}
								lastParsedObject = obj2
							}
						}
					}
				}
			}
		}

		err := stream.Err()
		if err != nil && !errors.Is(err, io.EOF) {
			streamErr = toProviderErr(err)
			yield(fantasy.ObjectStreamPart{
				Type:  fantasy.ObjectStreamPartTypeError,
				Error: streamErr,
			})
			return
		}

		if lastParsedObject != nil {
			yield(fantasy.ObjectStreamPart{
				Type:             fantasy.ObjectStreamPartTypeFinish,
				Usage:            usage,
				FinishReason:     finishReason,
				ProviderMetadata: providerMetadata,
			})
		} else {
			// Stream completed without ever producing a valid object.
			yield(fantasy.ObjectStreamPart{
				Type: fantasy.ObjectStreamPartTypeError,
				Error: &fantasy.NoObjectGeneratedError{
					RawText:      accumulated,
					ParseError:   fmt.Errorf("no valid object generated in stream"),
					Usage:        usage,
					FinishReason: finishReason,
				},
			})
		}
	}, nil
}
945
// addAdditionalPropertiesFalse recursively adds "additionalProperties": false
// to every object schema that does not already set it explicitly. This is
// required by OpenAI's strict mode for structured outputs, which demands the
// key on all object schemas, including nested and shared ones.
//
// Generalized over the original: besides "properties" and single-schema
// "items", it now also walks tuple-form "items" ([]any), shared definitions
// ("$defs"/"definitions"), and the "allOf"/"anyOf"/"oneOf" combinators.
// The parameter was renamed to avoid shadowing the imported schema package.
func addAdditionalPropertiesFalse(schemaMap map[string]any) {
	if schemaMap == nil {
		return
	}
	if schemaMap["type"] == "object" {
		if _, hasAdditional := schemaMap["additionalProperties"]; !hasAdditional {
			schemaMap["additionalProperties"] = false
		}
	}

	// Named subschemas: object properties and shared definitions.
	for _, key := range []string{"properties", "$defs", "definitions"} {
		if members, ok := schemaMap[key].(map[string]any); ok {
			for _, member := range members {
				if sub, ok := member.(map[string]any); ok {
					addAdditionalPropertiesFalse(sub)
				}
			}
		}
	}

	// Array items: either a single schema or a tuple of schemas.
	switch items := schemaMap["items"].(type) {
	case map[string]any:
		addAdditionalPropertiesFalse(items)
	case []any:
		for _, item := range items {
			if sub, ok := item.(map[string]any); ok {
				addAdditionalPropertiesFalse(sub)
			}
		}
	}

	// Schema combinators.
	for _, key := range []string{"allOf", "anyOf", "oneOf"} {
		if variants, ok := schemaMap[key].([]any); ok {
			for _, variant := range variants {
				if sub, ok := variant.(map[string]any); ok {
					addAdditionalPropertiesFalse(sub)
				}
			}
		}
	}
}