1package openai
2
3import (
4 "context"
5 "encoding/base64"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "strings"
11
12 "github.com/charmbracelet/fantasy/ai"
13 xjson "github.com/charmbracelet/x/json"
14 "github.com/google/uuid"
15 "github.com/openai/openai-go/v2"
16 "github.com/openai/openai-go/v2/packages/param"
17 "github.com/openai/openai-go/v2/shared"
18)
19
// languageModel is the chat-completions-based implementation of
// ai.LanguageModel for OpenAI-compatible APIs.
type languageModel struct {
	provider        string // provider identifier returned by Provider()
	modelID         string // model identifier returned by Model()
	client          openai.Client
	prepareCallFunc PrepareLanguageModelCallFunc // per-call hook that can adjust request params
}
26
// LanguageModelOption configures a languageModel during construction.
type LanguageModelOption = func(*languageModel)

// WithPrepareLanguageModelCallFunc overrides the default hook used to
// translate per-call provider options into request parameters.
func WithPrepareLanguageModelCallFunc(fn PrepareLanguageModelCallFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.prepareCallFunc = fn
	}
}
34
35func newLanguageModel(modelID string, provider string, client openai.Client, opts ...LanguageModelOption) languageModel {
36 model := languageModel{
37 modelID: modelID,
38 provider: provider,
39 client: client,
40 prepareCallFunc: defaultPrepareLanguageModelCall,
41 }
42
43 for _, o := range opts {
44 o(&model)
45 }
46 return model
47}
48
// streamToolCall accumulates one tool call assembled incrementally from
// streaming deltas, keyed in the Stream loop by the delta's index.
type streamToolCall struct {
	id          string
	name        string
	arguments   string // JSON argument fragments concatenated as they arrive
	hasFinished bool   // set once the completed call has been emitted downstream
}
55
// Model implements ai.LanguageModel. It returns the model identifier this
// instance was constructed with.
func (o languageModel) Model() string {
	return o.modelID
}
60
// Provider implements ai.LanguageModel. It returns the provider identifier
// this instance was constructed with.
func (o languageModel) Provider() string {
	return o.provider
}
65
66func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
67 params := &openai.ChatCompletionNewParams{}
68 messages, warnings := toPrompt(call.Prompt)
69 if call.TopK != nil {
70 warnings = append(warnings, ai.CallWarning{
71 Type: ai.CallWarningTypeUnsupportedSetting,
72 Setting: "top_k",
73 })
74 }
75
76 if call.MaxOutputTokens != nil {
77 params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
78 }
79 if call.Temperature != nil {
80 params.Temperature = param.NewOpt(*call.Temperature)
81 }
82 if call.TopP != nil {
83 params.TopP = param.NewOpt(*call.TopP)
84 }
85 if call.FrequencyPenalty != nil {
86 params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
87 }
88 if call.PresencePenalty != nil {
89 params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
90 }
91
92 if isReasoningModel(o.modelID) {
93 // remove unsupported settings for reasoning models
94 // see https://platform.openai.com/docs/guides/reasoning#limitations
95 if call.Temperature != nil {
96 params.Temperature = param.Opt[float64]{}
97 warnings = append(warnings, ai.CallWarning{
98 Type: ai.CallWarningTypeUnsupportedSetting,
99 Setting: "temperature",
100 Details: "temperature is not supported for reasoning models",
101 })
102 }
103 if call.TopP != nil {
104 params.TopP = param.Opt[float64]{}
105 warnings = append(warnings, ai.CallWarning{
106 Type: ai.CallWarningTypeUnsupportedSetting,
107 Setting: "TopP",
108 Details: "TopP is not supported for reasoning models",
109 })
110 }
111 if call.FrequencyPenalty != nil {
112 params.FrequencyPenalty = param.Opt[float64]{}
113 warnings = append(warnings, ai.CallWarning{
114 Type: ai.CallWarningTypeUnsupportedSetting,
115 Setting: "FrequencyPenalty",
116 Details: "FrequencyPenalty is not supported for reasoning models",
117 })
118 }
119 if call.PresencePenalty != nil {
120 params.PresencePenalty = param.Opt[float64]{}
121 warnings = append(warnings, ai.CallWarning{
122 Type: ai.CallWarningTypeUnsupportedSetting,
123 Setting: "PresencePenalty",
124 Details: "PresencePenalty is not supported for reasoning models",
125 })
126 }
127
128 // reasoning models use max_completion_tokens instead of max_tokens
129 if call.MaxOutputTokens != nil {
130 if !params.MaxCompletionTokens.Valid() {
131 params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
132 }
133 params.MaxTokens = param.Opt[int64]{}
134 }
135 }
136
137 // Handle search preview models
138 if isSearchPreviewModel(o.modelID) {
139 if call.Temperature != nil {
140 params.Temperature = param.Opt[float64]{}
141 warnings = append(warnings, ai.CallWarning{
142 Type: ai.CallWarningTypeUnsupportedSetting,
143 Setting: "temperature",
144 Details: "temperature is not supported for the search preview models and has been removed.",
145 })
146 }
147 }
148
149 optionsWarnings, err := o.prepareCallFunc(o, params, call)
150 if err != nil {
151 return nil, nil, err
152 }
153
154 if len(optionsWarnings) > 0 {
155 warnings = append(warnings, optionsWarnings...)
156 }
157
158 params.Messages = messages
159 params.Model = o.modelID
160
161 if len(call.Tools) > 0 {
162 tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
163 params.Tools = tools
164 if toolChoice != nil {
165 params.ToolChoice = *toolChoice
166 }
167 warnings = append(warnings, toolWarnings...)
168 }
169 return params, warnings, nil
170}
171
172func (o languageModel) handleError(err error) error {
173 var apiErr *openai.Error
174 if errors.As(err, &apiErr) {
175 requestDump := apiErr.DumpRequest(true)
176 responseDump := apiErr.DumpResponse(true)
177 headers := map[string]string{}
178 for k, h := range apiErr.Response.Header {
179 v := h[len(h)-1]
180 headers[strings.ToLower(k)] = v
181 }
182 return ai.NewAPICallError(
183 apiErr.Message,
184 apiErr.Request.URL.String(),
185 string(requestDump),
186 apiErr.StatusCode,
187 headers,
188 string(responseDump),
189 apiErr,
190 false,
191 )
192 }
193 return err
194}
195
196// Generate implements ai.LanguageModel.
197func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
198 params, warnings, err := o.prepareParams(call)
199 if err != nil {
200 return nil, err
201 }
202 response, err := o.client.Chat.Completions.New(ctx, *params)
203 if err != nil {
204 return nil, o.handleError(err)
205 }
206
207 if len(response.Choices) == 0 {
208 return nil, errors.New("no response generated")
209 }
210 choice := response.Choices[0]
211 content := make([]ai.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
212 text := choice.Message.Content
213 if text != "" {
214 content = append(content, ai.TextContent{
215 Text: text,
216 })
217 }
218
219 for _, tc := range choice.Message.ToolCalls {
220 toolCallID := tc.ID
221 if toolCallID == "" {
222 toolCallID = uuid.NewString()
223 }
224 content = append(content, ai.ToolCallContent{
225 ProviderExecuted: false, // TODO: update when handling other tools
226 ToolCallID: toolCallID,
227 ToolName: tc.Function.Name,
228 Input: tc.Function.Arguments,
229 })
230 }
231 // Handle annotations/citations
232 for _, annotation := range choice.Message.Annotations {
233 if annotation.Type == "url_citation" {
234 content = append(content, ai.SourceContent{
235 SourceType: ai.SourceTypeURL,
236 ID: uuid.NewString(),
237 URL: annotation.URLCitation.URL,
238 Title: annotation.URLCitation.Title,
239 })
240 }
241 }
242
243 completionTokenDetails := response.Usage.CompletionTokensDetails
244 promptTokenDetails := response.Usage.PromptTokensDetails
245
246 // Build provider metadata
247 providerMetadata := &ProviderMetadata{}
248 // Add logprobs if available
249 if len(choice.Logprobs.Content) > 0 {
250 providerMetadata.Logprobs = choice.Logprobs.Content
251 }
252
253 // Add prediction tokens if available
254 if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
255 if completionTokenDetails.AcceptedPredictionTokens > 0 {
256 providerMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
257 }
258 if completionTokenDetails.RejectedPredictionTokens > 0 {
259 providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
260 }
261 }
262
263 return &ai.Response{
264 Content: content,
265 Usage: ai.Usage{
266 InputTokens: response.Usage.PromptTokens,
267 OutputTokens: response.Usage.CompletionTokens,
268 TotalTokens: response.Usage.TotalTokens,
269 ReasoningTokens: completionTokenDetails.ReasoningTokens,
270 CacheReadTokens: promptTokenDetails.CachedTokens,
271 },
272 FinishReason: mapOpenAiFinishReason(choice.FinishReason),
273 ProviderMetadata: ai.ProviderMetadata{
274 Name: providerMetadata,
275 },
276 Warnings: warnings,
277 }, nil
278}
279
280// Stream implements ai.LanguageModel.
281func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
282 params, warnings, err := o.prepareParams(call)
283 if err != nil {
284 return nil, err
285 }
286
287 params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
288 IncludeUsage: openai.Bool(true),
289 }
290
291 stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
292 isActiveText := false
293 toolCalls := make(map[int64]streamToolCall)
294
295 // Build provider metadata for streaming
296 streamProviderMetadata := &ProviderMetadata{}
297 acc := openai.ChatCompletionAccumulator{}
298 var usage ai.Usage
299 return func(yield func(ai.StreamPart) bool) {
300 if len(warnings) > 0 {
301 if !yield(ai.StreamPart{
302 Type: ai.StreamPartTypeWarnings,
303 Warnings: warnings,
304 }) {
305 return
306 }
307 }
308 for stream.Next() {
309 chunk := stream.Current()
310 acc.AddChunk(chunk)
311 if chunk.Usage.TotalTokens > 0 {
312 // we do this here because the acc does not add prompt details
313 completionTokenDetails := chunk.Usage.CompletionTokensDetails
314 promptTokenDetails := chunk.Usage.PromptTokensDetails
315 usage = ai.Usage{
316 InputTokens: chunk.Usage.PromptTokens,
317 OutputTokens: chunk.Usage.CompletionTokens,
318 TotalTokens: chunk.Usage.TotalTokens,
319 ReasoningTokens: completionTokenDetails.ReasoningTokens,
320 CacheReadTokens: promptTokenDetails.CachedTokens,
321 }
322
323 // Add prediction tokens if available
324 if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
325 if completionTokenDetails.AcceptedPredictionTokens > 0 {
326 streamProviderMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
327 }
328 if completionTokenDetails.RejectedPredictionTokens > 0 {
329 streamProviderMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
330 }
331 }
332 }
333 if len(chunk.Choices) == 0 {
334 continue
335 }
336 for _, choice := range chunk.Choices {
337 switch {
338 case choice.Delta.Content != "":
339 if !isActiveText {
340 isActiveText = true
341 if !yield(ai.StreamPart{
342 Type: ai.StreamPartTypeTextStart,
343 ID: "0",
344 }) {
345 return
346 }
347 }
348 if !yield(ai.StreamPart{
349 Type: ai.StreamPartTypeTextDelta,
350 ID: "0",
351 Delta: choice.Delta.Content,
352 }) {
353 return
354 }
355 case len(choice.Delta.ToolCalls) > 0:
356 if isActiveText {
357 isActiveText = false
358 if !yield(ai.StreamPart{
359 Type: ai.StreamPartTypeTextEnd,
360 ID: "0",
361 }) {
362 return
363 }
364 }
365
366 for _, toolCallDelta := range choice.Delta.ToolCalls {
367 if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
368 if existingToolCall.hasFinished {
369 continue
370 }
371 if toolCallDelta.Function.Arguments != "" {
372 existingToolCall.arguments += toolCallDelta.Function.Arguments
373 }
374 if !yield(ai.StreamPart{
375 Type: ai.StreamPartTypeToolInputDelta,
376 ID: existingToolCall.id,
377 Delta: toolCallDelta.Function.Arguments,
378 }) {
379 return
380 }
381 toolCalls[toolCallDelta.Index] = existingToolCall
382 if xjson.IsValid(existingToolCall.arguments) {
383 if !yield(ai.StreamPart{
384 Type: ai.StreamPartTypeToolInputEnd,
385 ID: existingToolCall.id,
386 }) {
387 return
388 }
389
390 if !yield(ai.StreamPart{
391 Type: ai.StreamPartTypeToolCall,
392 ID: existingToolCall.id,
393 ToolCallName: existingToolCall.name,
394 ToolCallInput: existingToolCall.arguments,
395 }) {
396 return
397 }
398 existingToolCall.hasFinished = true
399 toolCalls[toolCallDelta.Index] = existingToolCall
400 }
401 } else {
402 // Does not exist
403 var err error
404 if toolCallDelta.Type != "function" {
405 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
406 }
407 if toolCallDelta.ID == "" {
408 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
409 }
410 if toolCallDelta.Function.Name == "" {
411 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
412 }
413 if err != nil {
414 yield(ai.StreamPart{
415 Type: ai.StreamPartTypeError,
416 Error: o.handleError(stream.Err()),
417 })
418 return
419 }
420
421 if !yield(ai.StreamPart{
422 Type: ai.StreamPartTypeToolInputStart,
423 ID: toolCallDelta.ID,
424 ToolCallName: toolCallDelta.Function.Name,
425 }) {
426 return
427 }
428 toolCalls[toolCallDelta.Index] = streamToolCall{
429 id: toolCallDelta.ID,
430 name: toolCallDelta.Function.Name,
431 arguments: toolCallDelta.Function.Arguments,
432 }
433
434 exTc := toolCalls[toolCallDelta.Index]
435 if exTc.arguments != "" {
436 if !yield(ai.StreamPart{
437 Type: ai.StreamPartTypeToolInputDelta,
438 ID: exTc.id,
439 Delta: exTc.arguments,
440 }) {
441 return
442 }
443 if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
444 if !yield(ai.StreamPart{
445 Type: ai.StreamPartTypeToolInputEnd,
446 ID: toolCallDelta.ID,
447 }) {
448 return
449 }
450
451 if !yield(ai.StreamPart{
452 Type: ai.StreamPartTypeToolCall,
453 ID: exTc.id,
454 ToolCallName: exTc.name,
455 ToolCallInput: exTc.arguments,
456 }) {
457 return
458 }
459 exTc.hasFinished = true
460 toolCalls[toolCallDelta.Index] = exTc
461 }
462 }
463 continue
464 }
465 }
466 }
467 }
468
469 // Check for annotations in the delta's raw JSON
470 for _, choice := range chunk.Choices {
471 if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
472 for _, annotation := range annotations {
473 if annotation.Type == "url_citation" {
474 if !yield(ai.StreamPart{
475 Type: ai.StreamPartTypeSource,
476 ID: uuid.NewString(),
477 SourceType: ai.SourceTypeURL,
478 URL: annotation.URLCitation.URL,
479 Title: annotation.URLCitation.Title,
480 }) {
481 return
482 }
483 }
484 }
485 }
486 }
487 }
488 err := stream.Err()
489 if err == nil || errors.Is(err, io.EOF) {
490 // finished
491 if isActiveText {
492 isActiveText = false
493 if !yield(ai.StreamPart{
494 Type: ai.StreamPartTypeTextEnd,
495 ID: "0",
496 }) {
497 return
498 }
499 }
500
501 // Add logprobs if available
502 if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
503 streamProviderMetadata.Logprobs = acc.Choices[0].Logprobs.Content
504 }
505
506 // Handle annotations/citations from accumulated response
507 if len(acc.Choices) > 0 {
508 for _, annotation := range acc.Choices[0].Message.Annotations {
509 if annotation.Type == "url_citation" {
510 if !yield(ai.StreamPart{
511 Type: ai.StreamPartTypeSource,
512 ID: acc.ID,
513 SourceType: ai.SourceTypeURL,
514 URL: annotation.URLCitation.URL,
515 Title: annotation.URLCitation.Title,
516 }) {
517 return
518 }
519 }
520 }
521 }
522
523 finishReason := mapOpenAiFinishReason(acc.Choices[0].FinishReason)
524 yield(ai.StreamPart{
525 Type: ai.StreamPartTypeFinish,
526 Usage: usage,
527 FinishReason: finishReason,
528 ProviderMetadata: ai.ProviderMetadata{
529 Name: streamProviderMetadata,
530 },
531 })
532 return
533 } else {
534 yield(ai.StreamPart{
535 Type: ai.StreamPartTypeError,
536 Error: o.handleError(err),
537 })
538 return
539 }
540 }, nil
541}
542
543func mapOpenAiFinishReason(finishReason string) ai.FinishReason {
544 switch finishReason {
545 case "stop":
546 return ai.FinishReasonStop
547 case "length":
548 return ai.FinishReasonLength
549 case "content_filter":
550 return ai.FinishReasonContentFilter
551 case "function_call", "tool_calls":
552 return ai.FinishReasonToolCalls
553 default:
554 return ai.FinishReasonUnknown
555 }
556}
557
// isReasoningModel reports whether the model belongs to OpenAI's reasoning
// families (o-series and gpt-5). gpt-5-chat models are conversational, not
// reasoning, and are excluded: the previous
// `|| strings.HasPrefix(modelID, "gpt-5-chat")` clause was dead code
// (already subsumed by the "gpt-5" prefix), inverting the intended
// exclusion.
func isReasoningModel(modelID string) bool {
	if strings.HasPrefix(modelID, "gpt-5-chat") {
		return false
	}
	return strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5")
}
561
// isSearchPreviewModel reports whether the model ID refers to one of the
// web-search preview chat models (e.g. gpt-4o-search-preview).
func isSearchPreviewModel(modelID string) bool {
	return strings.Contains(modelID, "search-preview")
}
565
// supportsFlexProcessing reports whether the model can be used with the
// "flex" service tier.
func supportsFlexProcessing(modelID string) bool {
	for _, prefix := range []string{"o3", "o4-mini", "gpt-5"} {
		if strings.HasPrefix(modelID, prefix) {
			return true
		}
	}
	return false
}
569
// supportsPriorityProcessing reports whether the model can be used with
// the "priority" service tier. The former explicit "gpt-5-mini" check was
// dead code — it is already covered by the "gpt-5" prefix — and has been
// removed; behavior is unchanged.
func supportsPriorityProcessing(modelID string) bool {
	return strings.HasPrefix(modelID, "gpt-4") ||
		strings.HasPrefix(modelID, "gpt-5") ||
		strings.HasPrefix(modelID, "o3") ||
		strings.HasPrefix(modelID, "o4-mini")
}
575
576func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
577 for _, tool := range tools {
578 if tool.GetType() == ai.ToolTypeFunction {
579 ft, ok := tool.(ai.FunctionTool)
580 if !ok {
581 continue
582 }
583 openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
584 OfFunction: &openai.ChatCompletionFunctionToolParam{
585 Function: shared.FunctionDefinitionParam{
586 Name: ft.Name,
587 Description: param.NewOpt(ft.Description),
588 Parameters: openai.FunctionParameters(ft.InputSchema),
589 Strict: param.NewOpt(false),
590 },
591 Type: "function",
592 },
593 })
594 continue
595 }
596
597 // TODO: handle provider tool calls
598 warnings = append(warnings, ai.CallWarning{
599 Type: ai.CallWarningTypeUnsupportedTool,
600 Tool: tool,
601 Message: "tool is not supported",
602 })
603 }
604 if toolChoice == nil {
605 return openAiTools, openAiToolChoice, warnings
606 }
607
608 switch *toolChoice {
609 case ai.ToolChoiceAuto:
610 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
611 OfAuto: param.NewOpt("auto"),
612 }
613 case ai.ToolChoiceNone:
614 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
615 OfAuto: param.NewOpt("none"),
616 }
617 default:
618 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
619 OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
620 Type: "function",
621 Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
622 Name: string(*toolChoice),
623 },
624 },
625 }
626 }
627 return openAiTools, openAiToolChoice, warnings
628}
629
// toPrompt converts a provider-agnostic ai.Prompt into OpenAI chat
// messages. Content that cannot be represented produces a warning and is
// skipped, so the request proceeds with whatever could be converted.
func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
	var messages []openai.ChatCompletionMessageParamUnion
	var warnings []ai.CallWarning
	for _, msg := range prompt {
		switch msg.Role {
		case ai.MessageRoleSystem:
			// All non-empty text parts are joined into one system message.
			var systemPromptParts []string
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeText {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt can only have text content",
					})
					continue
				}
				textPart, ok := ai.AsContentType[ai.TextPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt text part does not have the right type",
					})
					continue
				}
				text := textPart.Text
				// Whitespace-only parts are dropped.
				if strings.TrimSpace(text) != "" {
					systemPromptParts = append(systemPromptParts, textPart.Text)
				}
			}
			if len(systemPromptParts) == 0 {
				warnings = append(warnings, ai.CallWarning{
					Type:    ai.CallWarningTypeOther,
					Message: "system prompt has no text parts",
				})
				continue
			}
			messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
		case ai.MessageRoleUser:
			// simple user message just text content
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "user message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.UserMessage(textPart.Text))
				continue
			}
			// text content and attachments
			// for now we only support image content later we need to check
			// TODO: add the supported media types to the language model so we
			// can use that to validate the data here.
			var content []openai.ChatCompletionContentPartUnionParam
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message text part does not have the right type",
						})
						continue
					}
					content = append(content, openai.ChatCompletionContentPartUnionParam{
						OfText: &openai.ChatCompletionContentPartTextParam{
							Text: textPart.Text,
						},
					})
				case ai.ContentTypeFile:
					filePart, ok := ai.AsContentType[ai.FilePart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message file part does not have the right type",
						})
						continue
					}

					// Dispatch on media type: images, WAV/MP3 audio, and
					// PDFs are supported; everything else warns.
					switch {
					case strings.HasPrefix(filePart.MediaType, "image/"):
						// Handle image files as base64 data URLs.
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
						imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}

						// Check for provider-specific options like image detail
						if providerOptions, ok := filePart.ProviderOptions[Name]; ok {
							if detail, ok := providerOptions.(*ProviderFileOptions); ok {
								imageURL.Detail = detail.ImageDetail
							}
						}

						imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})

					case filePart.MediaType == "audio/wav":
						// Handle WAV audio files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "wav",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
						// Handle MP3 audio files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "mp3",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "application/pdf":
						// Handle PDF files
						dataStr := string(filePart.Data)

						// Check if data looks like a file ID (starts with "file-")
						if strings.HasPrefix(dataStr, "file-") {
							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									FileID: param.NewOpt(dataStr),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						} else {
							// Handle as base64 data
							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
							data := "data:application/pdf;base64," + base64Encoded

							filename := filePart.Filename
							if filename == "" {
								// Generate default filename based on content index
								filename = fmt.Sprintf("part-%d.pdf", len(content))
							}

							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									Filename: param.NewOpt(filename),
									FileData: param.NewOpt(data),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						}

					default:
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
						})
					}
				}
			}
			// NOTE(review): an empty content slice still emits a user
			// message here — confirm the API tolerates that.
			messages = append(messages, openai.UserMessage(content))
		case ai.MessageRoleAssistant:
			// simple assistant message just text content
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "assistant message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.AssistantMessage(textPart.Text))
				continue
			}
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message text part does not have the right type",
						})
						continue
					}
					// NOTE(review): a later text part overwrites an earlier
					// one — assumes at most one text part per assistant
					// message; confirm with callers.
					assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
						OfString: param.NewOpt(textPart.Text),
					}
				case ai.ContentTypeToolCall:
					toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message tool part does not have the right type",
						})
						continue
					}
					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
						openai.ChatCompletionMessageToolCallUnionParam{
							OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
								ID:   toolCallPart.ToolCallID,
								Type: "function",
								Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
									Name:      toolCallPart.ToolName,
									Arguments: toolCallPart.Input,
								},
							},
						})
				}
			}
			messages = append(messages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})
		case ai.MessageRoleTool:
			// Each tool result becomes its own tool message.
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeToolResult {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message can only have tool result content",
					})
					continue
				}

				toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message result part does not have the right type",
					})
					continue
				}

				switch toolResultPart.Output.GetType() {
				case ai.ToolResultContentTypeText:
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
				case ai.ToolResultContentTypeError:
					// TODO: check if better handling is needed
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
				}
			}
		}
	}
	return messages, warnings
}
894
895// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
896func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
897 var annotations []openai.ChatCompletionMessageAnnotation
898
899 // Parse the raw JSON to extract annotations
900 var deltaData map[string]any
901 if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
902 return annotations
903 }
904
905 // Check if annotations exist in the delta
906 if annotationsData, ok := deltaData["annotations"].([]any); ok {
907 for _, annotationData := range annotationsData {
908 if annotationMap, ok := annotationData.(map[string]any); ok {
909 if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
910 if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
911 annotation := openai.ChatCompletionMessageAnnotation{
912 Type: "url_citation",
913 URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
914 URL: urlCitationData["url"].(string),
915 Title: urlCitationData["title"].(string),
916 },
917 }
918 annotations = append(annotations, annotation)
919 }
920 }
921 }
922 }
923 }
924
925 return annotations
926}