1package openai
2
3import (
4 "context"
5 "encoding/base64"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "strings"
11
12 "github.com/charmbracelet/fantasy/ai"
13 xjson "github.com/charmbracelet/x/json"
14 "github.com/google/uuid"
15 "github.com/openai/openai-go/v2"
16 "github.com/openai/openai-go/v2/packages/param"
17 "github.com/openai/openai-go/v2/shared"
18)
19
// languageModel is the chat-completions implementation of ai.LanguageModel.
// The *Func fields are customization hooks: newLanguageModel installs
// defaults for most of them and the WithLanguageModel* options override them.
type languageModel struct {
	provider                   string // provider name reported by Provider()
	modelID                    string // model identifier sent as the request model
	client                     openai.Client
	prepareCallFunc            LanguageModelPrepareCallFunc            // final request mutation hook; may return extra warnings
	mapFinishReasonFunc        LanguageModelMapFinishReasonFunc        // maps an OpenAI choice to an ai finish reason
	extraContentFunc           LanguageModelExtraContentFunc           // optional extra content derived from a choice (nil = none)
	usageFunc                  LanguageModelUsageFunc                  // maps a completion response to usage + provider metadata
	streamUsageFunc            LanguageModelStreamUsageFunc            // per-chunk usage/metadata accumulation while streaming
	streamExtraFunc            LanguageModelStreamExtraFunc            // optional per-chunk stream hook (nil = none)
	streamProviderMetadataFunc LanguageModelStreamProviderMetadataFunc // final metadata from the accumulated stream choice
}
32
// LanguageModelOption configures a languageModel at construction time; see
// newLanguageModel and the WithLanguageModel* helpers.
type LanguageModelOption = func(*languageModel)
34
// WithLanguageModelPrepareCallFunc sets the hook that mutates the request
// parameters just before a call is made.
func WithLanguageModelPrepareCallFunc(fn LanguageModelPrepareCallFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.prepareCallFunc = fn
	}
}
40
// WithLanguageModelMapFinishReasonFunc sets the hook that maps an OpenAI
// choice to an ai finish reason.
func WithLanguageModelMapFinishReasonFunc(fn LanguageModelMapFinishReasonFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.mapFinishReasonFunc = fn
	}
}
46
// WithLanguageModelExtraContentFunc sets the hook that derives additional
// content parts from a (non-streaming) completion choice.
func WithLanguageModelExtraContentFunc(fn LanguageModelExtraContentFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.extraContentFunc = fn
	}
}
52
// WithLanguageModelStreamExtraFunc sets the per-chunk hook invoked while
// streaming.
func WithLanguageModelStreamExtraFunc(fn LanguageModelStreamExtraFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.streamExtraFunc = fn
	}
}
58
// WithLanguageModelUsageFunc sets the hook that maps a completion response
// to token usage and provider metadata.
func WithLanguageModelUsageFunc(fn LanguageModelUsageFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.usageFunc = fn
	}
}
64
// WithLanguageModelStreamUsageFunc sets the hook that accumulates usage and
// provider metadata from each streamed chunk.
func WithLanguageModelStreamUsageFunc(fn LanguageModelStreamUsageFunc) LanguageModelOption {
	return func(l *languageModel) {
		l.streamUsageFunc = fn
	}
}
70
71func newLanguageModel(modelID string, provider string, client openai.Client, opts ...LanguageModelOption) languageModel {
72 model := languageModel{
73 modelID: modelID,
74 provider: provider,
75 client: client,
76 prepareCallFunc: defaultPrepareLanguageModelCall,
77 mapFinishReasonFunc: defaultMapFinishReason,
78 usageFunc: defaultUsage,
79 streamUsageFunc: defaultStreamUsage,
80 streamProviderMetadataFunc: defaultStreamProviderMetadataFunc,
81 }
82
83 for _, o := range opts {
84 o(&model)
85 }
86 return model
87}
88
// streamToolCall accumulates a single tool call across streaming deltas.
type streamToolCall struct {
	id          string
	name        string
	arguments   string // concatenated JSON argument fragments received so far
	hasFinished bool   // true once a complete tool-call part has been emitted
}
95
// Model implements ai.LanguageModel. It returns the model identifier this
// instance was constructed with.
func (o languageModel) Model() string {
	return o.modelID
}
100
// Provider implements ai.LanguageModel. It returns the provider name this
// instance was constructed with.
func (o languageModel) Provider() string {
	return o.provider
}
105
// prepareParams translates an ai.Call into OpenAI chat-completion request
// parameters. It returns the params, warnings for settings that were dropped
// because the target model does not support them, and any error raised by
// the prepare-call hook.
func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
	params := &openai.ChatCompletionNewParams{}
	messages, warnings := toPrompt(call.Prompt)
	// Chat completions have no top_k parameter, so it is always dropped.
	if call.TopK != nil {
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedSetting,
			Setting: "top_k",
		})
	}

	if call.MaxOutputTokens != nil {
		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
	}
	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}
	if call.FrequencyPenalty != nil {
		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
	}
	if call.PresencePenalty != nil {
		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
	}

	if isReasoningModel(o.modelID) {
		// remove unsupported settings for reasoning models
		// see https://platform.openai.com/docs/guides/reasoning#limitations
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for reasoning models",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopP",
				Details: "TopP is not supported for reasoning models",
			})
		}
		if call.FrequencyPenalty != nil {
			params.FrequencyPenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "FrequencyPenalty",
				Details: "FrequencyPenalty is not supported for reasoning models",
			})
		}
		if call.PresencePenalty != nil {
			params.PresencePenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "PresencePenalty",
				Details: "PresencePenalty is not supported for reasoning models",
			})
		}

		// reasoning models use max_completion_tokens instead of max_tokens
		if call.MaxOutputTokens != nil {
			if !params.MaxCompletionTokens.Valid() {
				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
			}
			params.MaxTokens = param.Opt[int64]{}
		}
	}

	// Handle search preview models
	if isSearchPreviewModel(o.modelID) {
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for the search preview models and has been removed.",
			})
		}
	}

	// The provider-specific hook runs after the generic adjustments above and
	// may contribute additional warnings or veto the call entirely.
	optionsWarnings, err := o.prepareCallFunc(o, params, call)
	if err != nil {
		return nil, nil, err
	}

	if len(optionsWarnings) > 0 {
		warnings = append(warnings, optionsWarnings...)
	}

	params.Messages = messages
	params.Model = o.modelID

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}
	return params, warnings, nil
}
211
212func (o languageModel) handleError(err error) error {
213 var apiErr *openai.Error
214 if errors.As(err, &apiErr) {
215 requestDump := apiErr.DumpRequest(true)
216 responseDump := apiErr.DumpResponse(true)
217 headers := map[string]string{}
218 for k, h := range apiErr.Response.Header {
219 v := h[len(h)-1]
220 headers[strings.ToLower(k)] = v
221 }
222 return ai.NewAPICallError(
223 apiErr.Message,
224 apiErr.Request.URL.String(),
225 string(requestDump),
226 apiErr.StatusCode,
227 headers,
228 string(responseDump),
229 apiErr,
230 false,
231 )
232 }
233 return err
234}
235
236// Generate implements ai.LanguageModel.
237func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
238 params, warnings, err := o.prepareParams(call)
239 if err != nil {
240 return nil, err
241 }
242 response, err := o.client.Chat.Completions.New(ctx, *params)
243 if err != nil {
244 return nil, o.handleError(err)
245 }
246
247 if len(response.Choices) == 0 {
248 return nil, errors.New("no response generated")
249 }
250 choice := response.Choices[0]
251 content := make([]ai.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
252 text := choice.Message.Content
253 if text != "" {
254 content = append(content, ai.TextContent{
255 Text: text,
256 })
257 }
258 if o.extraContentFunc != nil {
259 extraContent := o.extraContentFunc(choice)
260 content = append(content, extraContent...)
261 }
262 for _, tc := range choice.Message.ToolCalls {
263 toolCallID := tc.ID
264 if toolCallID == "" {
265 toolCallID = uuid.NewString()
266 }
267 content = append(content, ai.ToolCallContent{
268 ProviderExecuted: false, // TODO: update when handling other tools
269 ToolCallID: toolCallID,
270 ToolName: tc.Function.Name,
271 Input: tc.Function.Arguments,
272 })
273 }
274 // Handle annotations/citations
275 for _, annotation := range choice.Message.Annotations {
276 if annotation.Type == "url_citation" {
277 content = append(content, ai.SourceContent{
278 SourceType: ai.SourceTypeURL,
279 ID: uuid.NewString(),
280 URL: annotation.URLCitation.URL,
281 Title: annotation.URLCitation.Title,
282 })
283 }
284 }
285
286 usage, providerMetadata := o.usageFunc(*response)
287
288 return &ai.Response{
289 Content: content,
290 Usage: usage,
291 FinishReason: defaultMapFinishReason(choice),
292 ProviderMetadata: ai.ProviderMetadata{
293 Name: providerMetadata,
294 },
295 Warnings: warnings,
296 }, nil
297}
298
299// Stream implements ai.LanguageModel.
300func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
301 params, warnings, err := o.prepareParams(call)
302 if err != nil {
303 return nil, err
304 }
305
306 params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
307 IncludeUsage: openai.Bool(true),
308 }
309
310 stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
311 isActiveText := false
312 toolCalls := make(map[int64]streamToolCall)
313
314 // Build provider metadata for streaming
315 providerMetadata := ai.ProviderMetadata{
316 Name: &ProviderMetadata{},
317 }
318 acc := openai.ChatCompletionAccumulator{}
319 extraContext := make(map[string]any)
320 var usage ai.Usage
321 return func(yield func(ai.StreamPart) bool) {
322 if len(warnings) > 0 {
323 if !yield(ai.StreamPart{
324 Type: ai.StreamPartTypeWarnings,
325 Warnings: warnings,
326 }) {
327 return
328 }
329 }
330 for stream.Next() {
331 chunk := stream.Current()
332 acc.AddChunk(chunk)
333 usage, providerMetadata = o.streamUsageFunc(chunk, extraContext, providerMetadata)
334 if len(chunk.Choices) == 0 {
335 continue
336 }
337 for _, choice := range chunk.Choices {
338 switch {
339 case choice.Delta.Content != "":
340 if !isActiveText {
341 isActiveText = true
342 if !yield(ai.StreamPart{
343 Type: ai.StreamPartTypeTextStart,
344 ID: "0",
345 }) {
346 return
347 }
348 }
349 if !yield(ai.StreamPart{
350 Type: ai.StreamPartTypeTextDelta,
351 ID: "0",
352 Delta: choice.Delta.Content,
353 }) {
354 return
355 }
356 case len(choice.Delta.ToolCalls) > 0:
357 if isActiveText {
358 isActiveText = false
359 if !yield(ai.StreamPart{
360 Type: ai.StreamPartTypeTextEnd,
361 ID: "0",
362 }) {
363 return
364 }
365 }
366
367 for _, toolCallDelta := range choice.Delta.ToolCalls {
368 if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
369 if existingToolCall.hasFinished {
370 continue
371 }
372 if toolCallDelta.Function.Arguments != "" {
373 existingToolCall.arguments += toolCallDelta.Function.Arguments
374 }
375 if !yield(ai.StreamPart{
376 Type: ai.StreamPartTypeToolInputDelta,
377 ID: existingToolCall.id,
378 Delta: toolCallDelta.Function.Arguments,
379 }) {
380 return
381 }
382 toolCalls[toolCallDelta.Index] = existingToolCall
383 if xjson.IsValid(existingToolCall.arguments) {
384 if !yield(ai.StreamPart{
385 Type: ai.StreamPartTypeToolInputEnd,
386 ID: existingToolCall.id,
387 }) {
388 return
389 }
390
391 if !yield(ai.StreamPart{
392 Type: ai.StreamPartTypeToolCall,
393 ID: existingToolCall.id,
394 ToolCallName: existingToolCall.name,
395 ToolCallInput: existingToolCall.arguments,
396 }) {
397 return
398 }
399 existingToolCall.hasFinished = true
400 toolCalls[toolCallDelta.Index] = existingToolCall
401 }
402 } else {
403 // Does not exist
404 var err error
405 if toolCallDelta.Type != "function" {
406 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
407 }
408 if toolCallDelta.ID == "" {
409 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
410 }
411 if toolCallDelta.Function.Name == "" {
412 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
413 }
414 if err != nil {
415 yield(ai.StreamPart{
416 Type: ai.StreamPartTypeError,
417 Error: o.handleError(stream.Err()),
418 })
419 return
420 }
421
422 if !yield(ai.StreamPart{
423 Type: ai.StreamPartTypeToolInputStart,
424 ID: toolCallDelta.ID,
425 ToolCallName: toolCallDelta.Function.Name,
426 }) {
427 return
428 }
429 toolCalls[toolCallDelta.Index] = streamToolCall{
430 id: toolCallDelta.ID,
431 name: toolCallDelta.Function.Name,
432 arguments: toolCallDelta.Function.Arguments,
433 }
434
435 exTc := toolCalls[toolCallDelta.Index]
436 if exTc.arguments != "" {
437 if !yield(ai.StreamPart{
438 Type: ai.StreamPartTypeToolInputDelta,
439 ID: exTc.id,
440 Delta: exTc.arguments,
441 }) {
442 return
443 }
444 if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
445 if !yield(ai.StreamPart{
446 Type: ai.StreamPartTypeToolInputEnd,
447 ID: toolCallDelta.ID,
448 }) {
449 return
450 }
451
452 if !yield(ai.StreamPart{
453 Type: ai.StreamPartTypeToolCall,
454 ID: exTc.id,
455 ToolCallName: exTc.name,
456 ToolCallInput: exTc.arguments,
457 }) {
458 return
459 }
460 exTc.hasFinished = true
461 toolCalls[toolCallDelta.Index] = exTc
462 }
463 }
464 continue
465 }
466 }
467 }
468
469 if o.streamExtraFunc != nil {
470 updatedContext, shouldContinue := o.streamExtraFunc(chunk, yield, extraContext)
471 if !shouldContinue {
472 return
473 }
474 extraContext = updatedContext
475 }
476 }
477
478 // Check for annotations in the delta's raw JSON
479 for _, choice := range chunk.Choices {
480 if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
481 for _, annotation := range annotations {
482 if annotation.Type == "url_citation" {
483 if !yield(ai.StreamPart{
484 Type: ai.StreamPartTypeSource,
485 ID: uuid.NewString(),
486 SourceType: ai.SourceTypeURL,
487 URL: annotation.URLCitation.URL,
488 Title: annotation.URLCitation.Title,
489 }) {
490 return
491 }
492 }
493 }
494 }
495 }
496 }
497 err := stream.Err()
498 if err == nil || errors.Is(err, io.EOF) {
499 // finished
500 if isActiveText {
501 isActiveText = false
502 if !yield(ai.StreamPart{
503 Type: ai.StreamPartTypeTextEnd,
504 ID: "0",
505 }) {
506 return
507 }
508 }
509
510 if len(acc.Choices) > 0 {
511 choice := acc.Choices[0]
512 // Add logprobs if available
513 providerMetadata = o.streamProviderMetadataFunc(choice, providerMetadata)
514
515 // Handle annotations/citations from accumulated response
516 for _, annotation := range choice.Message.Annotations {
517 if annotation.Type == "url_citation" {
518 if !yield(ai.StreamPart{
519 Type: ai.StreamPartTypeSource,
520 ID: acc.ID,
521 SourceType: ai.SourceTypeURL,
522 URL: annotation.URLCitation.URL,
523 Title: annotation.URLCitation.Title,
524 }) {
525 return
526 }
527 }
528 }
529 }
530 finishReason := ai.FinishReasonUnknown
531 if len(acc.Choices) > 0 {
532 finishReason = o.mapFinishReasonFunc(acc.Choices[0])
533 }
534 yield(ai.StreamPart{
535 Type: ai.StreamPartTypeFinish,
536 Usage: usage,
537 FinishReason: finishReason,
538 ProviderMetadata: providerMetadata,
539 })
540 return
541 } else {
542 yield(ai.StreamPart{
543 Type: ai.StreamPartTypeError,
544 Error: o.handleError(err),
545 })
546 return
547 }
548 }, nil
549}
550
// isReasoningModel reports whether the model belongs to a family that is
// treated as a reasoning model (o-series and gpt-5), which drops several
// sampling settings and uses max_completion_tokens instead of max_tokens.
//
// The previous extra check for the "gpt-5-chat" prefix was dead code: any
// model ID matching it already matches the "gpt-5" prefix, so removing it
// does not change behavior.
func isReasoningModel(modelID string) bool {
	return strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5")
}
554
// isSearchPreviewModel reports whether the model ID names one of the
// "search-preview" variants, which reject the temperature setting.
func isSearchPreviewModel(modelID string) bool {
	const marker = "search-preview"
	return strings.Contains(modelID, marker)
}
558
// supportsFlexProcessing reports whether the model family accepts the
// "flex" service tier.
func supportsFlexProcessing(modelID string) bool {
	for _, prefix := range []string{"o3", "o4-mini", "gpt-5"} {
		if strings.HasPrefix(modelID, prefix) {
			return true
		}
	}
	return false
}
562
// supportsPriorityProcessing reports whether the model family accepts the
// "priority" service tier.
//
// The previous "gpt-5-mini" prefix check was dead code — any ID matching it
// already matches the "gpt-5" prefix — so it has been removed without
// changing behavior.
func supportsPriorityProcessing(modelID string) bool {
	return strings.HasPrefix(modelID, "gpt-4") || strings.HasPrefix(modelID, "gpt-5") ||
		strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini")
}
568
// toOpenAiTools converts ai tools and an optional tool choice into their
// OpenAI chat-completions equivalents. Non-function tools are not converted
// and instead produce an unsupported-tool warning. A nil toolChoice results
// in a nil OpenAI tool choice (the API default).
func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				// Declared as a function tool but not the expected concrete
				// type; silently skipped.
				continue
			}
			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
				OfFunction: &openai.ChatCompletionFunctionToolParam{
					Function: shared.FunctionDefinitionParam{
						Name:        ft.Name,
						Description: param.NewOpt(ft.Description),
						Parameters:  openai.FunctionParameters(ft.InputSchema),
						Strict:      param.NewOpt(false),
					},
					Type: "function",
				},
			})
			continue
		}

		// TODO: handle provider tool calls
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return openAiTools, openAiToolChoice, warnings
	}

	switch *toolChoice {
	case ai.ToolChoiceAuto:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("auto"),
		}
	case ai.ToolChoiceNone:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("none"),
		}
	default:
		// Any other value is treated as the name of a specific function the
		// model is forced to call.
		// NOTE(review): if ai defines a "required" tool choice it would fall
		// into this branch and be sent as a function name — confirm against
		// the ai package.
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
				Type: "function",
				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
					Name: string(*toolChoice),
				},
			},
		}
	}
	return openAiTools, openAiToolChoice, warnings
}
622
// toPrompt converts an ai.Prompt into OpenAI chat-completion messages.
// Unsupported or malformed content parts are reported as CallWarnings and
// skipped, so a best-effort request can still be sent.
func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
	var messages []openai.ChatCompletionMessageParamUnion
	var warnings []ai.CallWarning
	for _, msg := range prompt {
		switch msg.Role {
		case ai.MessageRoleSystem:
			// All non-blank system text parts are joined into one message.
			var systemPromptParts []string
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeText {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt can only have text content",
					})
					continue
				}
				textPart, ok := ai.AsContentType[ai.TextPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt text part does not have the right type",
					})
					continue
				}
				text := textPart.Text
				if strings.TrimSpace(text) != "" {
					systemPromptParts = append(systemPromptParts, textPart.Text)
				}
			}
			if len(systemPromptParts) == 0 {
				warnings = append(warnings, ai.CallWarning{
					Type:    ai.CallWarningTypeOther,
					Message: "system prompt has no text parts",
				})
				continue
			}
			messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
		case ai.MessageRoleUser:
			// simple user message just text content
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "user message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.UserMessage(textPart.Text))
				continue
			}
			// text content and attachments
			// for now we only support image content later we need to check
			// TODO: add the supported media types to the language model so we
			// can use that to validate the data here.
			var content []openai.ChatCompletionContentPartUnionParam
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message text part does not have the right type",
						})
						continue
					}
					content = append(content, openai.ChatCompletionContentPartUnionParam{
						OfText: &openai.ChatCompletionContentPartTextParam{
							Text: textPart.Text,
						},
					})
				case ai.ContentTypeFile:
					filePart, ok := ai.AsContentType[ai.FilePart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message file part does not have the right type",
						})
						continue
					}

					switch {
					case strings.HasPrefix(filePart.MediaType, "image/"):
						// Handle image files: sent inline as a data: URL.
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
						imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}

						// Check for provider-specific options like image detail
						if providerOptions, ok := filePart.ProviderOptions[Name]; ok {
							if detail, ok := providerOptions.(*ProviderFileOptions); ok {
								imageURL.Detail = detail.ImageDetail
							}
						}

						imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})

					case filePart.MediaType == "audio/wav":
						// Handle WAV audio files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "wav",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
						// Handle MP3 audio files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "mp3",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "application/pdf":
						// Handle PDF files
						dataStr := string(filePart.Data)

						// Check if data looks like a file ID (starts with "file-")
						if strings.HasPrefix(dataStr, "file-") {
							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									FileID: param.NewOpt(dataStr),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						} else {
							// Handle as base64 data
							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
							data := "data:application/pdf;base64," + base64Encoded

							filename := filePart.Filename
							if filename == "" {
								// Generate default filename based on content index
								filename = fmt.Sprintf("part-%d.pdf", len(content))
							}

							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									Filename: param.NewOpt(filename),
									FileData: param.NewOpt(data),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						}

					default:
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
						})
					}
				}
			}
			messages = append(messages, openai.UserMessage(content))
		case ai.MessageRoleAssistant:
			// simple assistant message just text content
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "assistant message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.AssistantMessage(textPart.Text))
				continue
			}
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message text part does not have the right type",
						})
						continue
					}
					// NOTE(review): a later text part overwrites Content, so
					// only the last text part survives — confirm multi-text
					// assistant messages are not expected here.
					assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
						OfString: param.NewOpt(textPart.Text),
					}
				case ai.ContentTypeToolCall:
					toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message tool part does not have the right type",
						})
						continue
					}
					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
						openai.ChatCompletionMessageToolCallUnionParam{
							OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
								ID:   toolCallPart.ToolCallID,
								Type: "function",
								Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
									Name:      toolCallPart.ToolName,
									Arguments: toolCallPart.Input,
								},
							},
						})
				}
			}
			messages = append(messages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})
		case ai.MessageRoleTool:
			// Each tool-result part becomes its own tool message.
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeToolResult {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message can only have tool result content",
					})
					continue
				}

				toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message result part does not have the right type",
					})
					continue
				}

				switch toolResultPart.Output.GetType() {
				case ai.ToolResultContentTypeText:
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
				case ai.ToolResultContentTypeError:
					// TODO: check if better handling is needed
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
				}
			}
		}
	}
	return messages, warnings
}
887
888// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
889func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
890 var annotations []openai.ChatCompletionMessageAnnotation
891
892 // Parse the raw JSON to extract annotations
893 var deltaData map[string]any
894 if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
895 return annotations
896 }
897
898 // Check if annotations exist in the delta
899 if annotationsData, ok := deltaData["annotations"].([]any); ok {
900 for _, annotationData := range annotationsData {
901 if annotationMap, ok := annotationData.(map[string]any); ok {
902 if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
903 if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
904 annotation := openai.ChatCompletionMessageAnnotation{
905 Type: "url_citation",
906 URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
907 URL: urlCitationData["url"].(string),
908 Title: urlCitationData["title"].(string),
909 },
910 }
911 annotations = append(annotations, annotation)
912 }
913 }
914 }
915 }
916 }
917
918 return annotations
919}