// language_model.go — chat-completions implementation of ai.LanguageModel for OpenAI-compatible providers.

  1package openai
  2
  3import (
  4	"context"
  5	"encoding/base64"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"io"
 10	"strings"
 11
 12	"github.com/charmbracelet/fantasy/ai"
 13	xjson "github.com/charmbracelet/x/json"
 14	"github.com/google/uuid"
 15	"github.com/openai/openai-go/v2"
 16	"github.com/openai/openai-go/v2/packages/param"
 17	"github.com/openai/openai-go/v2/shared"
 18)
 19
// languageModel is the chat-completions-based implementation of
// ai.LanguageModel for OpenAI (and OpenAI-compatible) providers.
type languageModel struct {
	provider        string // provider name reported by Provider()
	modelID         string // model identifier reported by Model()
	client          openai.Client
	prepareCallFunc PrepareLanguageModelCallFunc // hook that customizes request params per call
}

// LanguageModelOption configures a languageModel during construction.
type LanguageModelOption = func(*languageModel)
 28
 29func WithPrepareLanguageModelCallFunc(fn PrepareLanguageModelCallFunc) LanguageModelOption {
 30	return func(l *languageModel) {
 31		l.prepareCallFunc = fn
 32	}
 33}
 34
 35func newLanguageModel(modelID string, provider string, client openai.Client, opts ...LanguageModelOption) languageModel {
 36	model := languageModel{
 37		modelID:         modelID,
 38		provider:        provider,
 39		client:          client,
 40		prepareCallFunc: defaultPrepareLanguageModelCall,
 41	}
 42
 43	for _, o := range opts {
 44		o(&model)
 45	}
 46	return model
 47}
 48
// streamToolCall tracks the in-flight state of one tool call while
// streaming: arguments accumulate across deltas until they form valid JSON,
// at which point the complete call is emitted and the entry is marked
// finished so later deltas for the same index are ignored.
type streamToolCall struct {
	id          string
	name        string
	arguments   string
	hasFinished bool
}
 55
// Model implements ai.LanguageModel by returning the configured model ID.
func (o languageModel) Model() string {
	return o.modelID
}
 60
// Provider implements ai.LanguageModel by returning the provider name.
func (o languageModel) Provider() string {
	return o.provider
}
 65
 66func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
 67	params := &openai.ChatCompletionNewParams{}
 68	messages, warnings := toPrompt(call.Prompt)
 69	if call.TopK != nil {
 70		warnings = append(warnings, ai.CallWarning{
 71			Type:    ai.CallWarningTypeUnsupportedSetting,
 72			Setting: "top_k",
 73		})
 74	}
 75
 76	if call.MaxOutputTokens != nil {
 77		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
 78	}
 79	if call.Temperature != nil {
 80		params.Temperature = param.NewOpt(*call.Temperature)
 81	}
 82	if call.TopP != nil {
 83		params.TopP = param.NewOpt(*call.TopP)
 84	}
 85	if call.FrequencyPenalty != nil {
 86		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
 87	}
 88	if call.PresencePenalty != nil {
 89		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
 90	}
 91
 92	if isReasoningModel(o.modelID) {
 93		// remove unsupported settings for reasoning models
 94		// see https://platform.openai.com/docs/guides/reasoning#limitations
 95		if call.Temperature != nil {
 96			params.Temperature = param.Opt[float64]{}
 97			warnings = append(warnings, ai.CallWarning{
 98				Type:    ai.CallWarningTypeUnsupportedSetting,
 99				Setting: "temperature",
100				Details: "temperature is not supported for reasoning models",
101			})
102		}
103		if call.TopP != nil {
104			params.TopP = param.Opt[float64]{}
105			warnings = append(warnings, ai.CallWarning{
106				Type:    ai.CallWarningTypeUnsupportedSetting,
107				Setting: "TopP",
108				Details: "TopP is not supported for reasoning models",
109			})
110		}
111		if call.FrequencyPenalty != nil {
112			params.FrequencyPenalty = param.Opt[float64]{}
113			warnings = append(warnings, ai.CallWarning{
114				Type:    ai.CallWarningTypeUnsupportedSetting,
115				Setting: "FrequencyPenalty",
116				Details: "FrequencyPenalty is not supported for reasoning models",
117			})
118		}
119		if call.PresencePenalty != nil {
120			params.PresencePenalty = param.Opt[float64]{}
121			warnings = append(warnings, ai.CallWarning{
122				Type:    ai.CallWarningTypeUnsupportedSetting,
123				Setting: "PresencePenalty",
124				Details: "PresencePenalty is not supported for reasoning models",
125			})
126		}
127
128		// reasoning models use max_completion_tokens instead of max_tokens
129		if call.MaxOutputTokens != nil {
130			if !params.MaxCompletionTokens.Valid() {
131				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
132			}
133			params.MaxTokens = param.Opt[int64]{}
134		}
135	}
136
137	// Handle search preview models
138	if isSearchPreviewModel(o.modelID) {
139		if call.Temperature != nil {
140			params.Temperature = param.Opt[float64]{}
141			warnings = append(warnings, ai.CallWarning{
142				Type:    ai.CallWarningTypeUnsupportedSetting,
143				Setting: "temperature",
144				Details: "temperature is not supported for the search preview models and has been removed.",
145			})
146		}
147	}
148
149	optionsWarnings, err := o.prepareCallFunc(o, params, call)
150	if err != nil {
151		return nil, nil, err
152	}
153
154	if len(optionsWarnings) > 0 {
155		warnings = append(warnings, optionsWarnings...)
156	}
157
158	params.Messages = messages
159	params.Model = o.modelID
160
161	if len(call.Tools) > 0 {
162		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
163		params.Tools = tools
164		if toolChoice != nil {
165			params.ToolChoice = *toolChoice
166		}
167		warnings = append(warnings, toolWarnings...)
168	}
169	return params, warnings, nil
170}
171
172func (o languageModel) handleError(err error) error {
173	var apiErr *openai.Error
174	if errors.As(err, &apiErr) {
175		requestDump := apiErr.DumpRequest(true)
176		responseDump := apiErr.DumpResponse(true)
177		headers := map[string]string{}
178		for k, h := range apiErr.Response.Header {
179			v := h[len(h)-1]
180			headers[strings.ToLower(k)] = v
181		}
182		return ai.NewAPICallError(
183			apiErr.Message,
184			apiErr.Request.URL.String(),
185			string(requestDump),
186			apiErr.StatusCode,
187			headers,
188			string(responseDump),
189			apiErr,
190			false,
191		)
192	}
193	return err
194}
195
196// Generate implements ai.LanguageModel.
197func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
198	params, warnings, err := o.prepareParams(call)
199	if err != nil {
200		return nil, err
201	}
202	response, err := o.client.Chat.Completions.New(ctx, *params)
203	if err != nil {
204		return nil, o.handleError(err)
205	}
206
207	if len(response.Choices) == 0 {
208		return nil, errors.New("no response generated")
209	}
210	choice := response.Choices[0]
211	content := make([]ai.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
212	text := choice.Message.Content
213	if text != "" {
214		content = append(content, ai.TextContent{
215			Text: text,
216		})
217	}
218
219	for _, tc := range choice.Message.ToolCalls {
220		toolCallID := tc.ID
221		if toolCallID == "" {
222			toolCallID = uuid.NewString()
223		}
224		content = append(content, ai.ToolCallContent{
225			ProviderExecuted: false, // TODO: update when handling other tools
226			ToolCallID:       toolCallID,
227			ToolName:         tc.Function.Name,
228			Input:            tc.Function.Arguments,
229		})
230	}
231	// Handle annotations/citations
232	for _, annotation := range choice.Message.Annotations {
233		if annotation.Type == "url_citation" {
234			content = append(content, ai.SourceContent{
235				SourceType: ai.SourceTypeURL,
236				ID:         uuid.NewString(),
237				URL:        annotation.URLCitation.URL,
238				Title:      annotation.URLCitation.Title,
239			})
240		}
241	}
242
243	completionTokenDetails := response.Usage.CompletionTokensDetails
244	promptTokenDetails := response.Usage.PromptTokensDetails
245
246	// Build provider metadata
247	providerMetadata := &ProviderMetadata{}
248	// Add logprobs if available
249	if len(choice.Logprobs.Content) > 0 {
250		providerMetadata.Logprobs = choice.Logprobs.Content
251	}
252
253	// Add prediction tokens if available
254	if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
255		if completionTokenDetails.AcceptedPredictionTokens > 0 {
256			providerMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
257		}
258		if completionTokenDetails.RejectedPredictionTokens > 0 {
259			providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
260		}
261	}
262
263	return &ai.Response{
264		Content: content,
265		Usage: ai.Usage{
266			InputTokens:     response.Usage.PromptTokens,
267			OutputTokens:    response.Usage.CompletionTokens,
268			TotalTokens:     response.Usage.TotalTokens,
269			ReasoningTokens: completionTokenDetails.ReasoningTokens,
270			CacheReadTokens: promptTokenDetails.CachedTokens,
271		},
272		FinishReason: mapOpenAiFinishReason(choice.FinishReason),
273		ProviderMetadata: ai.ProviderMetadata{
274			Name: providerMetadata,
275		},
276		Warnings: warnings,
277	}, nil
278}
279
280// Stream implements ai.LanguageModel.
281func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
282	params, warnings, err := o.prepareParams(call)
283	if err != nil {
284		return nil, err
285	}
286
287	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
288		IncludeUsage: openai.Bool(true),
289	}
290
291	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
292	isActiveText := false
293	toolCalls := make(map[int64]streamToolCall)
294
295	// Build provider metadata for streaming
296	streamProviderMetadata := &ProviderMetadata{}
297	acc := openai.ChatCompletionAccumulator{}
298	var usage ai.Usage
299	return func(yield func(ai.StreamPart) bool) {
300		if len(warnings) > 0 {
301			if !yield(ai.StreamPart{
302				Type:     ai.StreamPartTypeWarnings,
303				Warnings: warnings,
304			}) {
305				return
306			}
307		}
308		for stream.Next() {
309			chunk := stream.Current()
310			acc.AddChunk(chunk)
311			if chunk.Usage.TotalTokens > 0 {
312				// we do this here because the acc does not add prompt details
313				completionTokenDetails := chunk.Usage.CompletionTokensDetails
314				promptTokenDetails := chunk.Usage.PromptTokensDetails
315				usage = ai.Usage{
316					InputTokens:     chunk.Usage.PromptTokens,
317					OutputTokens:    chunk.Usage.CompletionTokens,
318					TotalTokens:     chunk.Usage.TotalTokens,
319					ReasoningTokens: completionTokenDetails.ReasoningTokens,
320					CacheReadTokens: promptTokenDetails.CachedTokens,
321				}
322
323				// Add prediction tokens if available
324				if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
325					if completionTokenDetails.AcceptedPredictionTokens > 0 {
326						streamProviderMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
327					}
328					if completionTokenDetails.RejectedPredictionTokens > 0 {
329						streamProviderMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
330					}
331				}
332			}
333			if len(chunk.Choices) == 0 {
334				continue
335			}
336			for _, choice := range chunk.Choices {
337				switch {
338				case choice.Delta.Content != "":
339					if !isActiveText {
340						isActiveText = true
341						if !yield(ai.StreamPart{
342							Type: ai.StreamPartTypeTextStart,
343							ID:   "0",
344						}) {
345							return
346						}
347					}
348					if !yield(ai.StreamPart{
349						Type:  ai.StreamPartTypeTextDelta,
350						ID:    "0",
351						Delta: choice.Delta.Content,
352					}) {
353						return
354					}
355				case len(choice.Delta.ToolCalls) > 0:
356					if isActiveText {
357						isActiveText = false
358						if !yield(ai.StreamPart{
359							Type: ai.StreamPartTypeTextEnd,
360							ID:   "0",
361						}) {
362							return
363						}
364					}
365
366					for _, toolCallDelta := range choice.Delta.ToolCalls {
367						if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
368							if existingToolCall.hasFinished {
369								continue
370							}
371							if toolCallDelta.Function.Arguments != "" {
372								existingToolCall.arguments += toolCallDelta.Function.Arguments
373							}
374							if !yield(ai.StreamPart{
375								Type:  ai.StreamPartTypeToolInputDelta,
376								ID:    existingToolCall.id,
377								Delta: toolCallDelta.Function.Arguments,
378							}) {
379								return
380							}
381							toolCalls[toolCallDelta.Index] = existingToolCall
382							if xjson.IsValid(existingToolCall.arguments) {
383								if !yield(ai.StreamPart{
384									Type: ai.StreamPartTypeToolInputEnd,
385									ID:   existingToolCall.id,
386								}) {
387									return
388								}
389
390								if !yield(ai.StreamPart{
391									Type:          ai.StreamPartTypeToolCall,
392									ID:            existingToolCall.id,
393									ToolCallName:  existingToolCall.name,
394									ToolCallInput: existingToolCall.arguments,
395								}) {
396									return
397								}
398								existingToolCall.hasFinished = true
399								toolCalls[toolCallDelta.Index] = existingToolCall
400							}
401						} else {
402							// Does not exist
403							var err error
404							if toolCallDelta.Type != "function" {
405								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
406							}
407							if toolCallDelta.ID == "" {
408								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
409							}
410							if toolCallDelta.Function.Name == "" {
411								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
412							}
413							if err != nil {
414								yield(ai.StreamPart{
415									Type:  ai.StreamPartTypeError,
416									Error: o.handleError(stream.Err()),
417								})
418								return
419							}
420
421							if !yield(ai.StreamPart{
422								Type:         ai.StreamPartTypeToolInputStart,
423								ID:           toolCallDelta.ID,
424								ToolCallName: toolCallDelta.Function.Name,
425							}) {
426								return
427							}
428							toolCalls[toolCallDelta.Index] = streamToolCall{
429								id:        toolCallDelta.ID,
430								name:      toolCallDelta.Function.Name,
431								arguments: toolCallDelta.Function.Arguments,
432							}
433
434							exTc := toolCalls[toolCallDelta.Index]
435							if exTc.arguments != "" {
436								if !yield(ai.StreamPart{
437									Type:  ai.StreamPartTypeToolInputDelta,
438									ID:    exTc.id,
439									Delta: exTc.arguments,
440								}) {
441									return
442								}
443								if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
444									if !yield(ai.StreamPart{
445										Type: ai.StreamPartTypeToolInputEnd,
446										ID:   toolCallDelta.ID,
447									}) {
448										return
449									}
450
451									if !yield(ai.StreamPart{
452										Type:          ai.StreamPartTypeToolCall,
453										ID:            exTc.id,
454										ToolCallName:  exTc.name,
455										ToolCallInput: exTc.arguments,
456									}) {
457										return
458									}
459									exTc.hasFinished = true
460									toolCalls[toolCallDelta.Index] = exTc
461								}
462							}
463							continue
464						}
465					}
466				}
467			}
468
469			// Check for annotations in the delta's raw JSON
470			for _, choice := range chunk.Choices {
471				if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
472					for _, annotation := range annotations {
473						if annotation.Type == "url_citation" {
474							if !yield(ai.StreamPart{
475								Type:       ai.StreamPartTypeSource,
476								ID:         uuid.NewString(),
477								SourceType: ai.SourceTypeURL,
478								URL:        annotation.URLCitation.URL,
479								Title:      annotation.URLCitation.Title,
480							}) {
481								return
482							}
483						}
484					}
485				}
486			}
487		}
488		err := stream.Err()
489		if err == nil || errors.Is(err, io.EOF) {
490			// finished
491			if isActiveText {
492				isActiveText = false
493				if !yield(ai.StreamPart{
494					Type: ai.StreamPartTypeTextEnd,
495					ID:   "0",
496				}) {
497					return
498				}
499			}
500
501			// Add logprobs if available
502			if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
503				streamProviderMetadata.Logprobs = acc.Choices[0].Logprobs.Content
504			}
505
506			// Handle annotations/citations from accumulated response
507			if len(acc.Choices) > 0 {
508				for _, annotation := range acc.Choices[0].Message.Annotations {
509					if annotation.Type == "url_citation" {
510						if !yield(ai.StreamPart{
511							Type:       ai.StreamPartTypeSource,
512							ID:         acc.ID,
513							SourceType: ai.SourceTypeURL,
514							URL:        annotation.URLCitation.URL,
515							Title:      annotation.URLCitation.Title,
516						}) {
517							return
518						}
519					}
520				}
521			}
522
523			finishReason := mapOpenAiFinishReason(acc.Choices[0].FinishReason)
524			yield(ai.StreamPart{
525				Type:         ai.StreamPartTypeFinish,
526				Usage:        usage,
527				FinishReason: finishReason,
528				ProviderMetadata: ai.ProviderMetadata{
529					Name: streamProviderMetadata,
530				},
531			})
532			return
533		} else {
534			yield(ai.StreamPart{
535				Type:  ai.StreamPartTypeError,
536				Error: o.handleError(err),
537			})
538			return
539		}
540	}, nil
541}
542
543func mapOpenAiFinishReason(finishReason string) ai.FinishReason {
544	switch finishReason {
545	case "stop":
546		return ai.FinishReasonStop
547	case "length":
548		return ai.FinishReasonLength
549	case "content_filter":
550		return ai.FinishReasonContentFilter
551	case "function_call", "tool_calls":
552		return ai.FinishReasonToolCalls
553	default:
554		return ai.FinishReasonUnknown
555	}
556}
557
// isReasoningModel reports whether the model is a reasoning model (the
// o-series or the gpt-5 family). The gpt-5-chat models are conventional
// chat models and are excluded: they do not accept reasoning-only request
// parameters. (The previous `|| HasPrefix("gpt-5-chat")` clause was dead
// code — already subsumed by the "gpt-5" prefix — and the evident intent
// was an exclusion, not an inclusion.)
func isReasoningModel(modelID string) bool {
	if strings.HasPrefix(modelID, "gpt-5-chat") {
		return false
	}
	return strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5")
}
561
// isSearchPreviewModel reports whether the model ID names one of the
// "search preview" chat models (e.g. gpt-4o-search-preview).
func isSearchPreviewModel(modelID string) bool {
	const marker = "search-preview"
	return strings.Contains(modelID, marker)
}
565
// supportsFlexProcessing reports whether the model supports the "flex"
// service tier.
func supportsFlexProcessing(modelID string) bool {
	for _, prefix := range []string{"o3", "o4-mini", "gpt-5"} {
		if strings.HasPrefix(modelID, prefix) {
			return true
		}
	}
	return false
}
569
// supportsPriorityProcessing reports whether the model supports the
// "priority" service tier. (The previous revision also checked the
// "gpt-5-mini" prefix, which was dead code — already covered by "gpt-5".)
func supportsPriorityProcessing(modelID string) bool {
	return strings.HasPrefix(modelID, "gpt-4") || strings.HasPrefix(modelID, "gpt-5") ||
		strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini")
}
575
576func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
577	for _, tool := range tools {
578		if tool.GetType() == ai.ToolTypeFunction {
579			ft, ok := tool.(ai.FunctionTool)
580			if !ok {
581				continue
582			}
583			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
584				OfFunction: &openai.ChatCompletionFunctionToolParam{
585					Function: shared.FunctionDefinitionParam{
586						Name:        ft.Name,
587						Description: param.NewOpt(ft.Description),
588						Parameters:  openai.FunctionParameters(ft.InputSchema),
589						Strict:      param.NewOpt(false),
590					},
591					Type: "function",
592				},
593			})
594			continue
595		}
596
597		// TODO: handle provider tool calls
598		warnings = append(warnings, ai.CallWarning{
599			Type:    ai.CallWarningTypeUnsupportedTool,
600			Tool:    tool,
601			Message: "tool is not supported",
602		})
603	}
604	if toolChoice == nil {
605		return openAiTools, openAiToolChoice, warnings
606	}
607
608	switch *toolChoice {
609	case ai.ToolChoiceAuto:
610		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
611			OfAuto: param.NewOpt("auto"),
612		}
613	case ai.ToolChoiceNone:
614		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
615			OfAuto: param.NewOpt("none"),
616		}
617	default:
618		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
619			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
620				Type: "function",
621				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
622					Name: string(*toolChoice),
623				},
624			},
625		}
626	}
627	return openAiTools, openAiToolChoice, warnings
628}
629
// toPrompt converts an ai.Prompt into OpenAI chat completion messages.
// Content that cannot be represented (non-text system parts, mistyped
// parts, unsupported file media types) is skipped and reported as a
// warning rather than failing the whole conversion.
func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
	var messages []openai.ChatCompletionMessageParamUnion
	var warnings []ai.CallWarning
	for _, msg := range prompt {
		switch msg.Role {
		case ai.MessageRoleSystem:
			// All non-blank text parts are joined (newline-separated) into a
			// single system message.
			var systemPromptParts []string
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeText {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt can only have text content",
					})
					continue
				}
				textPart, ok := ai.AsContentType[ai.TextPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt text part does not have the right type",
					})
					continue
				}
				text := textPart.Text
				if strings.TrimSpace(text) != "" {
					systemPromptParts = append(systemPromptParts, textPart.Text)
				}
			}
			if len(systemPromptParts) == 0 {
				warnings = append(warnings, ai.CallWarning{
					Type:    ai.CallWarningTypeOther,
					Message: "system prompt has no text parts",
				})
				continue
			}
			messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
		case ai.MessageRoleUser:
			// simple user message just text content
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "user message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.UserMessage(textPart.Text))
				continue
			}
			// text content and attachments
			// for now we only support image content later we need to check
			// TODO: add the supported media types to the language model so we
			//  can use that to validate the data here.
			var content []openai.ChatCompletionContentPartUnionParam
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message text part does not have the right type",
						})
						continue
					}
					content = append(content, openai.ChatCompletionContentPartUnionParam{
						OfText: &openai.ChatCompletionContentPartTextParam{
							Text: textPart.Text,
						},
					})
				case ai.ContentTypeFile:
					filePart, ok := ai.AsContentType[ai.FilePart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message file part does not have the right type",
						})
						continue
					}

					// Dispatch on media type: images, WAV/MP3 audio, and PDFs
					// are supported; everything else warns.
					switch {
					case strings.HasPrefix(filePart.MediaType, "image/"):
						// Handle image files: sent as a base64 data URL.
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
						imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}

						// Check for provider-specific options like image detail
						if providerOptions, ok := filePart.ProviderOptions[Name]; ok {
							if detail, ok := providerOptions.(*ProviderFileOptions); ok {
								imageURL.Detail = detail.ImageDetail
							}
						}

						imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})

					case filePart.MediaType == "audio/wav":
						// Handle WAV audio files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "wav",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
						// Handle MP3 audio files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "mp3",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "application/pdf":
						// Handle PDF files
						dataStr := string(filePart.Data)

						// Check if data looks like a file ID (starts with "file-")
						if strings.HasPrefix(dataStr, "file-") {
							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									FileID: param.NewOpt(dataStr),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						} else {
							// Handle as base64 data
							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
							data := "data:application/pdf;base64," + base64Encoded

							filename := filePart.Filename
							if filename == "" {
								// Generate default filename based on content index
								filename = fmt.Sprintf("part-%d.pdf", len(content))
							}

							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									Filename: param.NewOpt(filename),
									FileData: param.NewOpt(data),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						}

					default:
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
						})
					}
				}
			}
			messages = append(messages, openai.UserMessage(content))
		case ai.MessageRoleAssistant:
			// simple assistant message just text content
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "assistant message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.AssistantMessage(textPart.Text))
				continue
			}
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message text part does not have the right type",
						})
						continue
					}
					// NOTE(review): when a message has several text parts, each
					// assignment overwrites the previous one, so only the last
					// text part survives — confirm this is intended.
					assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
						OfString: param.NewOpt(textPart.Text),
					}
				case ai.ContentTypeToolCall:
					toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message tool part does not have the right type",
						})
						continue
					}
					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
						openai.ChatCompletionMessageToolCallUnionParam{
							OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
								ID:   toolCallPart.ToolCallID,
								Type: "function",
								Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
									Name:      toolCallPart.ToolName,
									Arguments: toolCallPart.Input,
								},
							},
						})
				}
			}
			messages = append(messages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})
		case ai.MessageRoleTool:
			// Tool messages carry tool results back to the model, one
			// openai.ToolMessage per result part.
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeToolResult {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message can only have tool result content",
					})
					continue
				}

				toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message result part does not have the right type",
					})
					continue
				}

				switch toolResultPart.Output.GetType() {
				case ai.ToolResultContentTypeText:
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
				case ai.ToolResultContentTypeError:
					// TODO: check if better handling is needed
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
				}
			}
		}
	}
	return messages, warnings
}
894
895// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
896func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
897	var annotations []openai.ChatCompletionMessageAnnotation
898
899	// Parse the raw JSON to extract annotations
900	var deltaData map[string]any
901	if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
902		return annotations
903	}
904
905	// Check if annotations exist in the delta
906	if annotationsData, ok := deltaData["annotations"].([]any); ok {
907		for _, annotationData := range annotationsData {
908			if annotationMap, ok := annotationData.(map[string]any); ok {
909				if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
910					if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
911						annotation := openai.ChatCompletionMessageAnnotation{
912							Type: "url_citation",
913							URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
914								URL:   urlCitationData["url"].(string),
915								Title: urlCitationData["title"].(string),
916							},
917						}
918						annotations = append(annotations, annotation)
919					}
920				}
921			}
922		}
923	}
924
925	return annotations
926}