package openai

import (
    "context"
    "encoding/base64"
    "encoding/json"
    "errors"
    "fmt"
    "io"
    "strings"

    "github.com/charmbracelet/fantasy/ai"
    xjson "github.com/charmbracelet/x/json"
    "github.com/google/uuid"
    "github.com/openai/openai-go/v2"
    "github.com/openai/openai-go/v2/packages/param"
    "github.com/openai/openai-go/v2/shared"
)

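// languageModel implements ai.LanguageModel on top of the OpenAI
// chat-completions API. Most of its behavior (ID generation, request
// preparation, finish-reason mapping, usage and metadata extraction) is
// pluggable through the LanguageModel*Func hooks below.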
type languageModel struct {
    provider                   string
    modelID                    string
    client                     openai.Client
    uniqueToolCallIds          bool
    generateIDFunc             LanguageModelGenerateIDFunc
    prepareCallFunc            LanguageModelPrepareCallFunc
    mapFinishReasonFunc        LanguageModelMapFinishReasonFunc
    extraContentFunc           LanguageModelExtraContentFunc
    usageFunc                  LanguageModelUsageFunc
    streamUsageFunc            LanguageModelStreamUsageFunc
    streamExtraFunc            LanguageModelStreamExtraFunc
    streamProviderMetadataFunc LanguageModelStreamProviderMetadataFunc
}

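// LanguageModelOption configures a languageModel created by newLanguageModel.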
type LanguageModelOption = func(*languageModel)

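// WithLanguageModelPrepareCallFunc sets the hook that adjusts the request
// parameters (and may emit warnings) before a call is sent to the API.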
func WithLanguageModelPrepareCallFunc(fn LanguageModelPrepareCallFunc) LanguageModelOption {
    return func(l *languageModel) {
        l.prepareCallFunc = fn
    }
}

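// WithLanguageModelMapFinishReasonFunc sets the hook that maps an OpenAI
// completion choice to an ai.FinishReason.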
func WithLanguageModelMapFinishReasonFunc(fn LanguageModelMapFinishReasonFunc) LanguageModelOption {
    return func(l *languageModel) {
        l.mapFinishReasonFunc = fn
    }
}

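// WithLanguageModelExtraContentFunc sets the hook that extracts additional
// provider-specific content from a completion choice.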
func WithLanguageModelExtraContentFunc(fn LanguageModelExtraContentFunc) LanguageModelOption {
    return func(l *languageModel) {
        l.extraContentFunc = fn
    }
}

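// WithLanguageModelStreamExtraFunc sets the hook that inspects each stream
// chunk, may emit additional stream parts, and can carry state between chunks
// via the extra context map.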
func WithLanguageModelStreamExtraFunc(fn LanguageModelStreamExtraFunc) LanguageModelOption {
    return func(l *languageModel) {
        l.streamExtraFunc = fn
    }
}

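// WithLanguageModelUsageFunc sets the hook that extracts token usage and
// provider metadata from a completion response.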
func WithLanguageModelUsageFunc(fn LanguageModelUsageFunc) LanguageModelOption {
    return func(l *languageModel) {
        l.usageFunc = fn
    }
}

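// WithLanguageModelStreamUsageFunc sets the hook that accumulates token usage
// and provider metadata from streaming chunks.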
func WithLanguageModelStreamUsageFunc(fn LanguageModelStreamUsageFunc) LanguageModelOption {
    return func(l *languageModel) {
        l.streamUsageFunc = fn
    }
}

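// WithLanguageUniqueToolCallIds forces tool call IDs to be regenerated with
// generateIDFunc so they are unique even when the provider reuses IDs.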
func WithLanguageUniqueToolCallIds() LanguageModelOption {
    return func(l *languageModel) {
        l.uniqueToolCallIds = true
    }
}

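// WithLanguageModelGenerateIDFunc sets the function used to generate tool call
// IDs when the provider omits them or unique IDs are requested.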
func WithLanguageModelGenerateIDFunc(fn LanguageModelGenerateIDFunc) LanguageModelOption {
    return func(l *languageModel) {
        l.generateIDFunc = fn
    }
}

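// newLanguageModel builds a languageModel with default hooks; options may
// override any of them. A minimal sketch of how a provider could construct one
// (the model ID and chosen options here are illustrative, not prescriptive):
//
//	model := newLanguageModel("gpt-4o", Name, client,
//		WithLanguageUniqueToolCallIds(),
//	)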
func newLanguageModel(modelID string, provider string, client openai.Client, opts ...LanguageModelOption) languageModel {
    model := languageModel{
        modelID:                    modelID,
        provider:                   provider,
        client:                     client,
        generateIDFunc:             defaultGenerateID,
        prepareCallFunc:            defaultPrepareLanguageModelCall,
        mapFinishReasonFunc:        defaultMapFinishReason,
        usageFunc:                  defaultUsage,
        streamUsageFunc:            defaultStreamUsage,
        streamProviderMetadataFunc: defaultStreamProviderMetadataFunc,
    }

    for _, o := range opts {
        o(&model)
    }
    return model
}

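// streamToolCall accumulates the pieces of a single tool call as they arrive
// across streaming deltas.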
type streamToolCall struct {
    id          string
    name        string
    arguments   string
    hasFinished bool
}

// Model implements ai.LanguageModel.
func (o languageModel) Model() string {
    return o.modelID
}

// Provider implements ai.LanguageModel.
func (o languageModel) Provider() string {
    return o.provider
}

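// prepareParams translates an ai.Call into OpenAI chat-completion parameters,
// collecting warnings for settings the target model does not support.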
func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
    params := &openai.ChatCompletionNewParams{}
    messages, warnings := toPrompt(call.Prompt)
    if call.TopK != nil {
        warnings = append(warnings, ai.CallWarning{
            Type: ai.CallWarningTypeUnsupportedSetting,
            Setting: "top_k",
        })
    }

    if call.MaxOutputTokens != nil {
        params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
    }
    if call.Temperature != nil {
        params.Temperature = param.NewOpt(*call.Temperature)
    }
    if call.TopP != nil {
        params.TopP = param.NewOpt(*call.TopP)
    }
    if call.FrequencyPenalty != nil {
        params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
    }
    if call.PresencePenalty != nil {
        params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
    }

    if isReasoningModel(o.modelID) {
        // Remove settings that reasoning models do not support.
        // See https://platform.openai.com/docs/guides/reasoning#limitations
        if call.Temperature != nil {
            params.Temperature = param.Opt[float64]{}
            warnings = append(warnings, ai.CallWarning{
                Type: ai.CallWarningTypeUnsupportedSetting,
                Setting: "temperature",
                Details: "temperature is not supported for reasoning models",
            })
        }
        if call.TopP != nil {
            params.TopP = param.Opt[float64]{}
            warnings = append(warnings, ai.CallWarning{
                Type: ai.CallWarningTypeUnsupportedSetting,
                Setting: "top_p",
                Details: "top_p is not supported for reasoning models",
            })
        }
        if call.FrequencyPenalty != nil {
            params.FrequencyPenalty = param.Opt[float64]{}
            warnings = append(warnings, ai.CallWarning{
                Type: ai.CallWarningTypeUnsupportedSetting,
                Setting: "frequency_penalty",
                Details: "frequency_penalty is not supported for reasoning models",
            })
        }
        if call.PresencePenalty != nil {
            params.PresencePenalty = param.Opt[float64]{}
            warnings = append(warnings, ai.CallWarning{
                Type: ai.CallWarningTypeUnsupportedSetting,
                Setting: "presence_penalty",
                Details: "presence_penalty is not supported for reasoning models",
            })
        }

        // Reasoning models use max_completion_tokens instead of max_tokens.
        if call.MaxOutputTokens != nil {
            if !params.MaxCompletionTokens.Valid() {
                params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
            }
            params.MaxTokens = param.Opt[int64]{}
        }
    }

    // Handle search preview models.
    if isSearchPreviewModel(o.modelID) {
        if call.Temperature != nil {
            params.Temperature = param.Opt[float64]{}
            warnings = append(warnings, ai.CallWarning{
                Type: ai.CallWarningTypeUnsupportedSetting,
                Setting: "temperature",
                Details: "temperature is not supported for the search preview models and has been removed.",
            })
        }
    }

    optionsWarnings, err := o.prepareCallFunc(o, params, call)
    if err != nil {
        return nil, nil, err
    }

    if len(optionsWarnings) > 0 {
        warnings = append(warnings, optionsWarnings...)
    }

    params.Messages = messages
    params.Model = o.modelID

    if len(call.Tools) > 0 {
        tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
        params.Tools = tools
        if toolChoice != nil {
            params.ToolChoice = *toolChoice
        }
        warnings = append(warnings, toolWarnings...)
    }
    return params, warnings, nil
}

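// handleError converts OpenAI SDK errors into ai.APICallError values,
// preserving the request/response dumps and response headers; other errors
// are returned unchanged.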
func (o languageModel) handleError(err error) error {
    var apiErr *openai.Error
    if errors.As(err, &apiErr) {
        requestDump := apiErr.DumpRequest(true)
        responseDump := apiErr.DumpResponse(true)
        headers := map[string]string{}
        for k, h := range apiErr.Response.Header {
            v := h[len(h)-1]
            headers[strings.ToLower(k)] = v
        }
        return ai.NewAPICallError(
            apiErr.Message,
            apiErr.Request.URL.String(),
            string(requestDump),
            apiErr.StatusCode,
            headers,
            string(responseDump),
            apiErr,
            false,
        )
    }
    return err
}

// Generate implements ai.LanguageModel.
func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
    params, warnings, err := o.prepareParams(call)
    if err != nil {
        return nil, err
    }
    response, err := o.client.Chat.Completions.New(ctx, *params)
    if err != nil {
        return nil, o.handleError(err)
    }

    if len(response.Choices) == 0 {
        return nil, errors.New("no response generated")
    }
    choice := response.Choices[0]
    content := make([]ai.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
    text := choice.Message.Content
    if text != "" {
        content = append(content, ai.TextContent{
            Text: text,
        })
    }
    if o.extraContentFunc != nil {
        extraContent := o.extraContentFunc(choice)
        content = append(content, extraContent...)
    }
    for _, tc := range choice.Message.ToolCalls {
        toolCallID := tc.ID
        if toolCallID == "" || o.uniqueToolCallIds {
            toolCallID = o.generateIDFunc()
        }
        content = append(content, ai.ToolCallContent{
            ProviderExecuted: false, // TODO: update when handling other tools
            ToolCallID: toolCallID,
            ToolName: tc.Function.Name,
            Input: tc.Function.Arguments,
        })
    }
    // Handle annotations/citations.
    for _, annotation := range choice.Message.Annotations {
        if annotation.Type == "url_citation" {
            content = append(content, ai.SourceContent{
                SourceType: ai.SourceTypeURL,
                ID: uuid.NewString(),
                URL: annotation.URLCitation.URL,
                Title: annotation.URLCitation.Title,
            })
        }
    }

    usage, providerMetadata := o.usageFunc(*response)

    return &ai.Response{
        Content: content,
        Usage: usage,
        FinishReason: o.mapFinishReasonFunc(choice),
        ProviderMetadata: ai.ProviderMetadata{
            Name: providerMetadata,
        },
        Warnings: warnings,
    }, nil
}

// Stream implements ai.LanguageModel.
func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
    params, warnings, err := o.prepareParams(call)
    if err != nil {
        return nil, err
    }

    params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
        IncludeUsage: openai.Bool(true),
    }

    stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
    isActiveText := false
    toolCalls := make(map[int64]streamToolCall)

    // Build provider metadata for streaming.
    providerMetadata := ai.ProviderMetadata{
        Name: &ProviderMetadata{},
    }
    acc := openai.ChatCompletionAccumulator{}
    extraContext := make(map[string]any)
    var usage ai.Usage
    return func(yield func(ai.StreamPart) bool) {
        if len(warnings) > 0 {
            if !yield(ai.StreamPart{
                Type: ai.StreamPartTypeWarnings,
                Warnings: warnings,
            }) {
                return
            }
        }
        for stream.Next() {
            chunk := stream.Current()
            acc.AddChunk(chunk)
            usage, providerMetadata = o.streamUsageFunc(chunk, extraContext, providerMetadata)
            if len(chunk.Choices) == 0 {
                continue
            }
            for _, choice := range chunk.Choices {
                switch {
                case choice.Delta.Content != "":
                    if !isActiveText {
                        isActiveText = true
                        if !yield(ai.StreamPart{
                            Type: ai.StreamPartTypeTextStart,
                            ID: "0",
                        }) {
                            return
                        }
                    }
                    if !yield(ai.StreamPart{
                        Type: ai.StreamPartTypeTextDelta,
                        ID: "0",
                        Delta: choice.Delta.Content,
                    }) {
                        return
                    }
                case len(choice.Delta.ToolCalls) > 0:
                    if isActiveText {
                        isActiveText = false
                        if !yield(ai.StreamPart{
                            Type: ai.StreamPartTypeTextEnd,
                            ID: "0",
                        }) {
                            return
                        }
                    }

                    for _, toolCallDelta := range choice.Delta.ToolCalls {
                        if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
                            if existingToolCall.hasFinished {
                                continue
                            }
                            if toolCallDelta.Function.Arguments != "" {
                                existingToolCall.arguments += toolCallDelta.Function.Arguments
                            }
                            if !yield(ai.StreamPart{
                                Type: ai.StreamPartTypeToolInputDelta,
                                ID: existingToolCall.id,
                                Delta: toolCallDelta.Function.Arguments,
                            }) {
                                return
                            }
                            toolCalls[toolCallDelta.Index] = existingToolCall
                            if xjson.IsValid(existingToolCall.arguments) {
                                if !yield(ai.StreamPart{
                                    Type: ai.StreamPartTypeToolInputEnd,
                                    ID: existingToolCall.id,
                                }) {
                                    return
                                }

                                if !yield(ai.StreamPart{
                                    Type: ai.StreamPartTypeToolCall,
                                    ID: existingToolCall.id,
                                    ToolCallName: existingToolCall.name,
                                    ToolCallInput: existingToolCall.arguments,
                                }) {
                                    return
                                }
                                existingToolCall.hasFinished = true
                                toolCalls[toolCallDelta.Index] = existingToolCall
                            }
                        } else {
                            // First delta for this tool call index.
                            var err error
                            if toolCallDelta.Type != "function" {
                                err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
                            }
                            if toolCallDelta.ID == "" {
                                err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
                            }
                            if toolCallDelta.Function.Name == "" {
                                err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
                            }
                            if err != nil {
                                yield(ai.StreamPart{
                                    Type: ai.StreamPartTypeError,
                                    Error: o.handleError(err),
                                })
                                return
                            }

                            // Some providers do not send a unique tool call ID, but
                            // some use cases in Crush need this ID to be unique.
                            // Regenerating it does not change behavior on the
                            // provider side: the provider only cares that the tool
                            // call ID matches the result, which still holds here.
                            if o.uniqueToolCallIds {
                                toolCallDelta.ID = o.generateIDFunc()
                            }

                            if !yield(ai.StreamPart{
                                Type: ai.StreamPartTypeToolInputStart,
                                ID: toolCallDelta.ID,
                                ToolCallName: toolCallDelta.Function.Name,
                            }) {
                                return
                            }
                            toolCalls[toolCallDelta.Index] = streamToolCall{
                                id: toolCallDelta.ID,
                                name: toolCallDelta.Function.Name,
                                arguments: toolCallDelta.Function.Arguments,
                            }

                            exTc := toolCalls[toolCallDelta.Index]
                            if exTc.arguments != "" {
                                if !yield(ai.StreamPart{
                                    Type: ai.StreamPartTypeToolInputDelta,
                                    ID: exTc.id,
                                    Delta: exTc.arguments,
                                }) {
                                    return
                                }
                                if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
                                    if !yield(ai.StreamPart{
                                        Type: ai.StreamPartTypeToolInputEnd,
                                        ID: toolCallDelta.ID,
                                    }) {
                                        return
                                    }

                                    if !yield(ai.StreamPart{
                                        Type: ai.StreamPartTypeToolCall,
                                        ID: exTc.id,
                                        ToolCallName: exTc.name,
                                        ToolCallInput: exTc.arguments,
                                    }) {
                                        return
                                    }
                                    exTc.hasFinished = true
                                    toolCalls[toolCallDelta.Index] = exTc
                                }
                            }
                            continue
                        }
                    }
                }

                if o.streamExtraFunc != nil {
                    updatedContext, shouldContinue := o.streamExtraFunc(chunk, yield, extraContext)
                    if !shouldContinue {
                        return
                    }
                    extraContext = updatedContext
                }
            }

            // Check for annotations in the delta's raw JSON.
            for _, choice := range chunk.Choices {
                if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
                    for _, annotation := range annotations {
                        if annotation.Type == "url_citation" {
                            if !yield(ai.StreamPart{
                                Type: ai.StreamPartTypeSource,
                                ID: uuid.NewString(),
                                SourceType: ai.SourceTypeURL,
                                URL: annotation.URLCitation.URL,
                                Title: annotation.URLCitation.Title,
                            }) {
                                return
                            }
                        }
                    }
                }
            }
        }
        err := stream.Err()
        if err == nil || errors.Is(err, io.EOF) {
            // Finished streaming: close any open text block, then emit the
            // finish part with accumulated usage and metadata.
            if isActiveText {
                isActiveText = false
                if !yield(ai.StreamPart{
                    Type: ai.StreamPartTypeTextEnd,
                    ID: "0",
                }) {
                    return
                }
            }

            if len(acc.Choices) > 0 {
                choice := acc.Choices[0]
                // Add logprobs if available.
                providerMetadata = o.streamProviderMetadataFunc(choice, providerMetadata)

                // Handle annotations/citations from the accumulated response.
                for _, annotation := range choice.Message.Annotations {
                    if annotation.Type == "url_citation" {
                        if !yield(ai.StreamPart{
                            Type: ai.StreamPartTypeSource,
                            ID: acc.ID,
                            SourceType: ai.SourceTypeURL,
                            URL: annotation.URLCitation.URL,
                            Title: annotation.URLCitation.Title,
                        }) {
                            return
                        }
                    }
                }
            }
            finishReason := ai.FinishReasonUnknown
            if len(acc.Choices) > 0 {
                finishReason = o.mapFinishReasonFunc(acc.Choices[0])
            }
            yield(ai.StreamPart{
                Type: ai.StreamPartTypeFinish,
                Usage: usage,
                FinishReason: finishReason,
                ProviderMetadata: providerMetadata,
            })
            return
        } else {
            yield(ai.StreamPart{
                Type: ai.StreamPartTypeError,
                Error: o.handleError(err),
            })
            return
        }
    }, nil
}

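// isReasoningModel reports whether the model is part of the reasoning family
// (o-series and gpt-5), which rejects sampling settings such as temperature
// and expects max_completion_tokens instead of max_tokens.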
func isReasoningModel(modelID string) bool {
    return strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5") || strings.HasPrefix(modelID, "gpt-5-chat")
}

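// isSearchPreviewModel reports whether the model is one of the search preview
// variants, which do not support temperature.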
func isSearchPreviewModel(modelID string) bool {
    return strings.Contains(modelID, "search-preview")
}

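// supportsFlexProcessing reports whether the model can use flex processing
// (the "flex" service tier).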
func supportsFlexProcessing(modelID string) bool {
    return strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini") || strings.HasPrefix(modelID, "gpt-5")
}

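// supportsPriorityProcessing reports whether the model can use priority
// processing (the "priority" service tier).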
func supportsPriorityProcessing(modelID string) bool {
    return strings.HasPrefix(modelID, "gpt-4") || strings.HasPrefix(modelID, "gpt-5") ||
        strings.HasPrefix(modelID, "gpt-5-mini") || strings.HasPrefix(modelID, "o3") ||
        strings.HasPrefix(modelID, "o4-mini")
}

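// toOpenAiTools converts the call's tools and tool choice into their OpenAI
// chat-completion equivalents, emitting warnings for unsupported tool types.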
func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
    for _, tool := range tools {
        if tool.GetType() == ai.ToolTypeFunction {
            ft, ok := tool.(ai.FunctionTool)
            if !ok {
                continue
            }
            openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
                OfFunction: &openai.ChatCompletionFunctionToolParam{
                    Function: shared.FunctionDefinitionParam{
                        Name: ft.Name,
                        Description: param.NewOpt(ft.Description),
                        Parameters: openai.FunctionParameters(ft.InputSchema),
                        Strict: param.NewOpt(false),
                    },
                    Type: "function",
                },
            })
            continue
        }

        // TODO: handle provider tool calls
        warnings = append(warnings, ai.CallWarning{
            Type: ai.CallWarningTypeUnsupportedTool,
            Tool: tool,
            Message: "tool is not supported",
        })
    }
    if toolChoice == nil {
        return openAiTools, openAiToolChoice, warnings
    }

    switch *toolChoice {
    case ai.ToolChoiceAuto:
        openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
            OfAuto: param.NewOpt("auto"),
        }
    case ai.ToolChoiceNone:
        openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
            OfAuto: param.NewOpt("none"),
        }
    default:
        openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
            OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
                Type: "function",
                Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
                    Name: string(*toolChoice),
                },
            },
        }
    }
    return openAiTools, openAiToolChoice, warnings
}

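// toPrompt converts an ai.Prompt into OpenAI chat messages, flattening system
// prompts into a single message and mapping text, file, tool-call, and
// tool-result parts; unsupported content is reported as warnings.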
func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
    var messages []openai.ChatCompletionMessageParamUnion
    var warnings []ai.CallWarning
    for _, msg := range prompt {
        switch msg.Role {
        case ai.MessageRoleSystem:
            var systemPromptParts []string
            for _, c := range msg.Content {
                if c.GetType() != ai.ContentTypeText {
                    warnings = append(warnings, ai.CallWarning{
                        Type: ai.CallWarningTypeOther,
                        Message: "system prompt can only have text content",
                    })
                    continue
                }
                textPart, ok := ai.AsContentType[ai.TextPart](c)
                if !ok {
                    warnings = append(warnings, ai.CallWarning{
                        Type: ai.CallWarningTypeOther,
                        Message: "system prompt text part does not have the right type",
                    })
                    continue
                }
                text := textPart.Text
                if strings.TrimSpace(text) != "" {
                    systemPromptParts = append(systemPromptParts, textPart.Text)
                }
            }
            if len(systemPromptParts) == 0 {
                warnings = append(warnings, ai.CallWarning{
                    Type: ai.CallWarningTypeOther,
                    Message: "system prompt has no text parts",
                })
                continue
            }
            messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
        case ai.MessageRoleUser:
            // Simple user message: just text content.
            if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
                textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
                if !ok {
                    warnings = append(warnings, ai.CallWarning{
                        Type: ai.CallWarningTypeOther,
                        Message: "user message text part does not have the right type",
                    })
                    continue
                }
                messages = append(messages, openai.UserMessage(textPart.Text))
                continue
            }
            // Text content plus attachments.
            // TODO: add the supported media types to the language model so we
            // can use them to validate the file data here.
            var content []openai.ChatCompletionContentPartUnionParam
            for _, c := range msg.Content {
                switch c.GetType() {
                case ai.ContentTypeText:
                    textPart, ok := ai.AsContentType[ai.TextPart](c)
                    if !ok {
                        warnings = append(warnings, ai.CallWarning{
                            Type: ai.CallWarningTypeOther,
                            Message: "user message text part does not have the right type",
                        })
                        continue
                    }
                    content = append(content, openai.ChatCompletionContentPartUnionParam{
                        OfText: &openai.ChatCompletionContentPartTextParam{
                            Text: textPart.Text,
                        },
                    })
                case ai.ContentTypeFile:
                    filePart, ok := ai.AsContentType[ai.FilePart](c)
                    if !ok {
                        warnings = append(warnings, ai.CallWarning{
                            Type: ai.CallWarningTypeOther,
                            Message: "user message file part does not have the right type",
                        })
                        continue
                    }

                    switch {
                    case strings.HasPrefix(filePart.MediaType, "image/"):
                        // Handle image files.
                        base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
                        data := "data:" + filePart.MediaType + ";base64," + base64Encoded
                        imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}

                        // Check for provider-specific options like image detail.
                        if providerOptions, ok := filePart.ProviderOptions[Name]; ok {
                            if detail, ok := providerOptions.(*ProviderFileOptions); ok {
                                imageURL.Detail = detail.ImageDetail
                            }
                        }

                        imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
                        content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})

                    case filePart.MediaType == "audio/wav":
                        // Handle WAV audio files.
                        base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
                        audioBlock := openai.ChatCompletionContentPartInputAudioParam{
                            InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
                                Data: base64Encoded,
                                Format: "wav",
                            },
                        }
                        content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

                    case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
                        // Handle MP3 audio files.
                        base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
                        audioBlock := openai.ChatCompletionContentPartInputAudioParam{
                            InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
                                Data: base64Encoded,
                                Format: "mp3",
                            },
                        }
                        content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

                    case filePart.MediaType == "application/pdf":
                        // Handle PDF files.
                        dataStr := string(filePart.Data)

                        // Check if the data looks like a file ID (starts with "file-").
                        if strings.HasPrefix(dataStr, "file-") {
                            fileBlock := openai.ChatCompletionContentPartFileParam{
                                File: openai.ChatCompletionContentPartFileFileParam{
                                    FileID: param.NewOpt(dataStr),
                                },
                            }
                            content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
                        } else {
                            // Handle as base64 data.
                            base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
                            data := "data:application/pdf;base64," + base64Encoded

                            filename := filePart.Filename
                            if filename == "" {
                                // Generate a default filename based on the content index.
                                filename = fmt.Sprintf("part-%d.pdf", len(content))
                            }

                            fileBlock := openai.ChatCompletionContentPartFileParam{
                                File: openai.ChatCompletionContentPartFileFileParam{
                                    Filename: param.NewOpt(filename),
                                    FileData: param.NewOpt(data),
                                },
                            }
                            content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
                        }

                    default:
                        warnings = append(warnings, ai.CallWarning{
                            Type: ai.CallWarningTypeOther,
                            Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
                        })
                    }
                }
            }
            messages = append(messages, openai.UserMessage(content))
        case ai.MessageRoleAssistant:
            // Simple assistant message: just text content.
            if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
                textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
                if !ok {
                    warnings = append(warnings, ai.CallWarning{
                        Type: ai.CallWarningTypeOther,
                        Message: "assistant message text part does not have the right type",
                    })
                    continue
                }
                messages = append(messages, openai.AssistantMessage(textPart.Text))
                continue
            }
            assistantMsg := openai.ChatCompletionAssistantMessageParam{
                Role: "assistant",
            }
            for _, c := range msg.Content {
                switch c.GetType() {
                case ai.ContentTypeText:
                    textPart, ok := ai.AsContentType[ai.TextPart](c)
                    if !ok {
                        warnings = append(warnings, ai.CallWarning{
                            Type: ai.CallWarningTypeOther,
                            Message: "assistant message text part does not have the right type",
                        })
                        continue
                    }
                    assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
                        OfString: param.NewOpt(textPart.Text),
                    }
                case ai.ContentTypeToolCall:
                    toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
                    if !ok {
                        warnings = append(warnings, ai.CallWarning{
                            Type: ai.CallWarningTypeOther,
                            Message: "assistant message tool part does not have the right type",
                        })
                        continue
                    }
                    assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
                        openai.ChatCompletionMessageToolCallUnionParam{
                            OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
                                ID: toolCallPart.ToolCallID,
                                Type: "function",
                                Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
                                    Name: toolCallPart.ToolName,
                                    Arguments: toolCallPart.Input,
                                },
                            },
                        })
                }
            }
            messages = append(messages, openai.ChatCompletionMessageParamUnion{
                OfAssistant: &assistantMsg,
            })
        case ai.MessageRoleTool:
            for _, c := range msg.Content {
                if c.GetType() != ai.ContentTypeToolResult {
                    warnings = append(warnings, ai.CallWarning{
                        Type: ai.CallWarningTypeOther,
                        Message: "tool message can only have tool result content",
                    })
                    continue
                }

                toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
                if !ok {
                    warnings = append(warnings, ai.CallWarning{
                        Type: ai.CallWarningTypeOther,
                        Message: "tool message result part does not have the right type",
                    })
                    continue
                }

                switch toolResultPart.Output.GetType() {
                case ai.ToolResultContentTypeText:
                    output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
                    if !ok {
                        warnings = append(warnings, ai.CallWarning{
                            Type: ai.CallWarningTypeOther,
                            Message: "tool result output does not have the right type",
                        })
                        continue
                    }
                    messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
                case ai.ToolResultContentTypeError:
                    // TODO: check if better handling is needed
                    output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
                    if !ok {
                        warnings = append(warnings, ai.CallWarning{
                            Type: ai.CallWarningTypeOther,
                            Message: "tool result output does not have the right type",
                        })
                        continue
                    }
                    messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
                }
            }
        }
    }
    return messages, warnings
}

// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
    var annotations []openai.ChatCompletionMessageAnnotation

    // Parse the raw JSON to extract annotations.
    var deltaData map[string]any
    if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
        return annotations
    }

    // Check if annotations exist in the delta.
    if annotationsData, ok := deltaData["annotations"].([]any); ok {
        for _, annotationData := range annotationsData {
            if annotationMap, ok := annotationData.(map[string]any); ok {
                if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
                    if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
                        // Use checked assertions so malformed citation fields
                        // cannot panic; missing values fall back to empty strings.
                        url, _ := urlCitationData["url"].(string)
                        title, _ := urlCitationData["title"].(string)
                        annotations = append(annotations, openai.ChatCompletionMessageAnnotation{
                            Type: "url_citation",
                            URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
                                URL: url,
                                Title: title,
                            },
                        })
                    }
                }
            }
        }
    }

    return annotations
}