package openai

import (
    "context"
    "encoding/base64"
    "encoding/json"
    "errors"
    "fmt"
    "io"
    "strings"

    "github.com/charmbracelet/fantasy/ai"
    xjson "github.com/charmbracelet/x/json"
    "github.com/google/uuid"
    "github.com/openai/openai-go/v2"
    "github.com/openai/openai-go/v2/packages/param"
    "github.com/openai/openai-go/v2/shared"
)

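// languageModel implements ai.LanguageModel on top of the OpenAI
// chat-completions API. Most of its behavior (ID generation, request
// preparation, finish-reason mapping, usage and metadata extraction) is
// pluggable through the LanguageModel*Func hooks below.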
type languageModel struct {
    provider                   string
    modelID                    string
    client                     openai.Client
    uniqueToolCallIds          bool
    generateIDFunc             LanguageModelGenerateIDFunc
    prepareCallFunc            LanguageModelPrepareCallFunc
    mapFinishReasonFunc        LanguageModelMapFinishReasonFunc
    extraContentFunc           LanguageModelExtraContentFunc
    usageFunc                  LanguageModelUsageFunc
    streamUsageFunc            LanguageModelStreamUsageFunc
    streamExtraFunc            LanguageModelStreamExtraFunc
    streamProviderMetadataFunc LanguageModelStreamProviderMetadataFunc
}

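// LanguageModelOption configures a languageModel created by newLanguageModel.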
type LanguageModelOption = func(*languageModel)

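// WithLanguageModelPrepareCallFunc sets the hook that adjusts the request
// parameters (and may emit warnings) before a call is sent to the API.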
func WithLanguageModelPrepareCallFunc(fn LanguageModelPrepareCallFunc) LanguageModelOption {
    return func(l *languageModel) {
        l.prepareCallFunc = fn
    }
}

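// WithLanguageModelMapFinishReasonFunc sets the hook that maps an OpenAI
// completion choice to an ai.FinishReason.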
func WithLanguageModelMapFinishReasonFunc(fn LanguageModelMapFinishReasonFunc) LanguageModelOption {
    return func(l *languageModel) {
        l.mapFinishReasonFunc = fn
    }
}

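// WithLanguageModelExtraContentFunc sets the hook that extracts additional
// provider-specific content from a completion choice.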
func WithLanguageModelExtraContentFunc(fn LanguageModelExtraContentFunc) LanguageModelOption {
    return func(l *languageModel) {
        l.extraContentFunc = fn
    }
}

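// WithLanguageModelStreamExtraFunc sets the hook that inspects each stream
// chunk, may emit additional stream parts, and can carry state between chunks
// via the extra context map.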
func WithLanguageModelStreamExtraFunc(fn LanguageModelStreamExtraFunc) LanguageModelOption {
    return func(l *languageModel) {
        l.streamExtraFunc = fn
    }
}

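// WithLanguageModelUsageFunc sets the hook that extracts token usage and
// provider metadata from a completion response.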
func WithLanguageModelUsageFunc(fn LanguageModelUsageFunc) LanguageModelOption {
    return func(l *languageModel) {
        l.usageFunc = fn
    }
}

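// WithLanguageModelStreamUsageFunc sets the hook that accumulates token usage
// and provider metadata from streaming chunks.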
func WithLanguageModelStreamUsageFunc(fn LanguageModelStreamUsageFunc) LanguageModelOption {
    return func(l *languageModel) {
        l.streamUsageFunc = fn
    }
}

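// WithLanguageUniqueToolCallIds forces tool call IDs to be regenerated with
// generateIDFunc so they are unique even when the provider reuses IDs.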
func WithLanguageUniqueToolCallIds() LanguageModelOption {
    return func(l *languageModel) {
        l.uniqueToolCallIds = true
    }
}

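// WithLanguageModelGenerateIDFunc sets the function used to generate tool call
// IDs when the provider omits them or unique IDs are requested.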
func WithLanguageModelGenerateIDFunc(fn LanguageModelGenerateIDFunc) LanguageModelOption {
    return func(l *languageModel) {
        l.generateIDFunc = fn
    }
}

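// newLanguageModel builds a languageModel with default hooks; options may
// override any of them. A minimal sketch of how a provider could construct one
// (the model ID and chosen options here are illustrative, not prescriptive):
//
//	model := newLanguageModel("gpt-4o", Name, client,
//		WithLanguageUniqueToolCallIds(),
//	)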
func newLanguageModel(modelID string, provider string, client openai.Client, opts ...LanguageModelOption) languageModel {
    model := languageModel{
        modelID:                    modelID,
        provider:                   provider,
        client:                     client,
        generateIDFunc:             defaultGenerateID,
        prepareCallFunc:            defaultPrepareLanguageModelCall,
        mapFinishReasonFunc:        defaultMapFinishReason,
        usageFunc:                  defaultUsage,
        streamUsageFunc:            defaultStreamUsage,
        streamProviderMetadataFunc: defaultStreamProviderMetadataFunc,
    }

    for _, o := range opts {
        o(&model)
    }
    return model
}

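// streamToolCall accumulates the pieces of a single tool call as they arrive
// across streaming deltas.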
type streamToolCall struct {
    id          string
    name        string
    arguments   string
    hasFinished bool
}

// Model implements ai.LanguageModel.
func (o languageModel) Model() string {
    return o.modelID
}

// Provider implements ai.LanguageModel.
func (o languageModel) Provider() string {
    return o.provider
}

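// prepareParams translates an ai.Call into OpenAI chat-completion parameters,
// collecting warnings for settings the target model does not support.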
func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
    params := &openai.ChatCompletionNewParams{}
    messages, warnings := toPrompt(call.Prompt)
    if call.TopK != nil {
        warnings = append(warnings, ai.CallWarning{
            Type: ai.CallWarningTypeUnsupportedSetting,
            Setting: "top_k",
        })
    }

    if call.MaxOutputTokens != nil {
        params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
    }
    if call.Temperature != nil {
        params.Temperature = param.NewOpt(*call.Temperature)
    }
    if call.TopP != nil {
        params.TopP = param.NewOpt(*call.TopP)
    }
    if call.FrequencyPenalty != nil {
        params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
    }
    if call.PresencePenalty != nil {
        params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
    }

    if isReasoningModel(o.modelID) {
        // Remove settings that reasoning models do not support.
        // See https://platform.openai.com/docs/guides/reasoning#limitations
        if call.Temperature != nil {
            params.Temperature = param.Opt[float64]{}
            warnings = append(warnings, ai.CallWarning{
                Type: ai.CallWarningTypeUnsupportedSetting,
                Setting: "temperature",
                Details: "temperature is not supported for reasoning models",
            })
        }
        if call.TopP != nil {
            params.TopP = param.Opt[float64]{}
            warnings = append(warnings, ai.CallWarning{
                Type: ai.CallWarningTypeUnsupportedSetting,
                Setting: "top_p",
                Details: "top_p is not supported for reasoning models",
            })
        }
        if call.FrequencyPenalty != nil {
            params.FrequencyPenalty = param.Opt[float64]{}
            warnings = append(warnings, ai.CallWarning{
                Type: ai.CallWarningTypeUnsupportedSetting,
                Setting: "frequency_penalty",
                Details: "frequency_penalty is not supported for reasoning models",
            })
        }
        if call.PresencePenalty != nil {
            params.PresencePenalty = param.Opt[float64]{}
            warnings = append(warnings, ai.CallWarning{
                Type: ai.CallWarningTypeUnsupportedSetting,
                Setting: "presence_penalty",
                Details: "presence_penalty is not supported for reasoning models",
            })
        }

        // Reasoning models use max_completion_tokens instead of max_tokens.
        if call.MaxOutputTokens != nil {
            if !params.MaxCompletionTokens.Valid() {
                params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
            }
            params.MaxTokens = param.Opt[int64]{}
        }
    }

    // Handle search preview models.
    if isSearchPreviewModel(o.modelID) {
        if call.Temperature != nil {
            params.Temperature = param.Opt[float64]{}
            warnings = append(warnings, ai.CallWarning{
                Type: ai.CallWarningTypeUnsupportedSetting,
                Setting: "temperature",
                Details: "temperature is not supported for the search preview models and has been removed.",
            })
        }
    }

    optionsWarnings, err := o.prepareCallFunc(o, params, call)
    if err != nil {
        return nil, nil, err
    }

    if len(optionsWarnings) > 0 {
        warnings = append(warnings, optionsWarnings...)
    }

    params.Messages = messages
    params.Model = o.modelID

    if len(call.Tools) > 0 {
        tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
        params.Tools = tools
        if toolChoice != nil {
            params.ToolChoice = *toolChoice
        }
        warnings = append(warnings, toolWarnings...)
    }
    return params, warnings, nil
}

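// handleError converts OpenAI SDK errors into ai.APICallError values,
// preserving the request/response dumps and response headers; other errors
// are returned unchanged.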
func (o languageModel) handleError(err error) error {
    var apiErr *openai.Error
    if errors.As(err, &apiErr) {
        requestDump := apiErr.DumpRequest(true)
        responseDump := apiErr.DumpResponse(true)
        headers := map[string]string{}
        for k, h := range apiErr.Response.Header {
            v := h[len(h)-1]
            headers[strings.ToLower(k)] = v
        }
        return ai.NewAPICallError(
            apiErr.Message,
            apiErr.Request.URL.String(),
            string(requestDump),
            apiErr.StatusCode,
            headers,
            string(responseDump),
            apiErr,
            false,
        )
    }
    return err
}

// Generate implements ai.LanguageModel.
func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
    params, warnings, err := o.prepareParams(call)
    if err != nil {
        return nil, err
    }
    response, err := o.client.Chat.Completions.New(ctx, *params)
    if err != nil {
        return nil, o.handleError(err)
    }

    if len(response.Choices) == 0 {
        return nil, errors.New("no response generated")
    }
    choice := response.Choices[0]
    content := make([]ai.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
    text := choice.Message.Content
    if text != "" {
        content = append(content, ai.TextContent{
            Text: text,
        })
    }
    if o.extraContentFunc != nil {
        extraContent := o.extraContentFunc(choice)
        content = append(content, extraContent...)
    }
    for _, tc := range choice.Message.ToolCalls {
        toolCallID := tc.ID
        if toolCallID == "" || o.uniqueToolCallIds {
            toolCallID = o.generateIDFunc()
        }
        content = append(content, ai.ToolCallContent{
            ProviderExecuted: false, // TODO: update when handling other tools
            ToolCallID: toolCallID,
            ToolName: tc.Function.Name,
            Input: tc.Function.Arguments,
        })
    }
    // Handle annotations/citations.
    for _, annotation := range choice.Message.Annotations {
        if annotation.Type == "url_citation" {
            content = append(content, ai.SourceContent{
                SourceType: ai.SourceTypeURL,
                ID: uuid.NewString(),
                URL: annotation.URLCitation.URL,
                Title: annotation.URLCitation.Title,
            })
        }
    }

    usage, providerMetadata := o.usageFunc(*response)

    return &ai.Response{
        Content: content,
        Usage: usage,
        FinishReason: o.mapFinishReasonFunc(choice),
        ProviderMetadata: ai.ProviderMetadata{
            Name: providerMetadata,
        },
        Warnings: warnings,
    }, nil
}

// Stream implements ai.LanguageModel.
func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
    params, warnings, err := o.prepareParams(call)
    if err != nil {
        return nil, err
    }

    params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
        IncludeUsage: openai.Bool(true),
    }

    stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
    isActiveText := false
    toolCalls := make(map[int64]streamToolCall)

    // Build provider metadata for streaming.
    providerMetadata := ai.ProviderMetadata{
        Name: &ProviderMetadata{},
    }
    acc := openai.ChatCompletionAccumulator{}
    extraContext := make(map[string]any)
    var usage ai.Usage
    return func(yield func(ai.StreamPart) bool) {
        if len(warnings) > 0 {
            if !yield(ai.StreamPart{
                Type: ai.StreamPartTypeWarnings,
                Warnings: warnings,
            }) {
                return
            }
        }
        for stream.Next() {
            chunk := stream.Current()
            acc.AddChunk(chunk)
            usage, providerMetadata = o.streamUsageFunc(chunk, extraContext, providerMetadata)
            if len(chunk.Choices) == 0 {
                continue
            }
            for _, choice := range chunk.Choices {
                switch {
                case choice.Delta.Content != "":
                    if !isActiveText {
                        isActiveText = true
                        if !yield(ai.StreamPart{
                            Type: ai.StreamPartTypeTextStart,
                            ID: "0",
                        }) {
                            return
                        }
                    }
                    if !yield(ai.StreamPart{
                        Type: ai.StreamPartTypeTextDelta,
                        ID: "0",
                        Delta: choice.Delta.Content,
                    }) {
                        return
                    }
                case len(choice.Delta.ToolCalls) > 0:
                    if isActiveText {
                        isActiveText = false
                        if !yield(ai.StreamPart{
                            Type: ai.StreamPartTypeTextEnd,
                            ID: "0",
                        }) {
                            return
                        }
                    }

                    for _, toolCallDelta := range choice.Delta.ToolCalls {
                        if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
                            if existingToolCall.hasFinished {
                                continue
                            }
                            if toolCallDelta.Function.Arguments != "" {
                                existingToolCall.arguments += toolCallDelta.Function.Arguments
                            }
                            if !yield(ai.StreamPart{
                                Type: ai.StreamPartTypeToolInputDelta,
                                ID: existingToolCall.id,
                                Delta: toolCallDelta.Function.Arguments,
                            }) {
                                return
                            }
                            toolCalls[toolCallDelta.Index] = existingToolCall
                            if xjson.IsValid(existingToolCall.arguments) {
                                if !yield(ai.StreamPart{
                                    Type: ai.StreamPartTypeToolInputEnd,
                                    ID: existingToolCall.id,
                                }) {
                                    return
                                }

                                if !yield(ai.StreamPart{
                                    Type: ai.StreamPartTypeToolCall,
                                    ID: existingToolCall.id,
                                    ToolCallName: existingToolCall.name,
                                    ToolCallInput: existingToolCall.arguments,
                                }) {
                                    return
                                }
                                existingToolCall.hasFinished = true
                                toolCalls[toolCallDelta.Index] = existingToolCall
                            }
                        } else {
                            // First delta for this tool call index.
                            var err error
                            if toolCallDelta.Type != "function" {
                                err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
                            }
                            if toolCallDelta.ID == "" {
                                err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
                            }
                            if toolCallDelta.Function.Name == "" {
                                err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
                            }
                            if err != nil {
                                yield(ai.StreamPart{
                                    Type: ai.StreamPartTypeError,
                                    Error: o.handleError(err),
                                })
                                return
                            }

                            // Some providers do not send a unique tool call ID, but
                            // some use cases in Crush need this ID to be unique.
                            // Regenerating it does not change behavior on the
                            // provider side: the provider only cares that the tool
                            // call ID matches the result, which still holds here.
                            if o.uniqueToolCallIds {
                                toolCallDelta.ID = o.generateIDFunc()
                            }

                            if !yield(ai.StreamPart{
                                Type: ai.StreamPartTypeToolInputStart,
                                ID: toolCallDelta.ID,
                                ToolCallName: toolCallDelta.Function.Name,
                            }) {
                                return
                            }
                            toolCalls[toolCallDelta.Index] = streamToolCall{
                                id: toolCallDelta.ID,
                                name: toolCallDelta.Function.Name,
                                arguments: toolCallDelta.Function.Arguments,
                            }

                            exTc := toolCalls[toolCallDelta.Index]
                            if exTc.arguments != "" {
                                if !yield(ai.StreamPart{
                                    Type: ai.StreamPartTypeToolInputDelta,
                                    ID: exTc.id,
                                    Delta: exTc.arguments,
                                }) {
                                    return
                                }
                                if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
                                    if !yield(ai.StreamPart{
                                        Type: ai.StreamPartTypeToolInputEnd,
                                        ID: toolCallDelta.ID,
                                    }) {
                                        return
                                    }

                                    if !yield(ai.StreamPart{
                                        Type: ai.StreamPartTypeToolCall,
                                        ID: exTc.id,
                                        ToolCallName: exTc.name,
                                        ToolCallInput: exTc.arguments,
                                    }) {
                                        return
                                    }
                                    exTc.hasFinished = true
                                    toolCalls[toolCallDelta.Index] = exTc
                                }
                            }
                            continue
                        }
                    }
                }

                if o.streamExtraFunc != nil {
                    updatedContext, shouldContinue := o.streamExtraFunc(chunk, yield, extraContext)
                    if !shouldContinue {
                        return
                    }
                    extraContext = updatedContext
                }
            }

            // Check for annotations in the delta's raw JSON.
            for _, choice := range chunk.Choices {
                if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
                    for _, annotation := range annotations {
                        if annotation.Type == "url_citation" {
                            if !yield(ai.StreamPart{
                                Type: ai.StreamPartTypeSource,
                                ID: uuid.NewString(),
                                SourceType: ai.SourceTypeURL,
                                URL: annotation.URLCitation.URL,
                                Title: annotation.URLCitation.Title,
                            }) {
                                return
                            }
                        }
                    }
                }
            }
        }
        err := stream.Err()
        if err == nil || errors.Is(err, io.EOF) {
            // Finished streaming: close any open text block, then emit the
            // finish part with accumulated usage and metadata.
            if isActiveText {
                isActiveText = false
                if !yield(ai.StreamPart{
                    Type: ai.StreamPartTypeTextEnd,
                    ID: "0",
                }) {
                    return
                }
            }

            if len(acc.Choices) > 0 {
                choice := acc.Choices[0]
                // Add logprobs if available.
                providerMetadata = o.streamProviderMetadataFunc(choice, providerMetadata)

                // Handle annotations/citations from the accumulated response.
                for _, annotation := range choice.Message.Annotations {
                    if annotation.Type == "url_citation" {
                        if !yield(ai.StreamPart{
                            Type: ai.StreamPartTypeSource,
                            ID: acc.ID,
                            SourceType: ai.SourceTypeURL,
                            URL: annotation.URLCitation.URL,
                            Title: annotation.URLCitation.Title,
                        }) {
                            return
                        }
                    }
                }
            }
            finishReason := ai.FinishReasonUnknown
            if len(acc.Choices) > 0 {
                finishReason = o.mapFinishReasonFunc(acc.Choices[0])
            }
            yield(ai.StreamPart{
                Type: ai.StreamPartTypeFinish,
                Usage: usage,
                FinishReason: finishReason,
                ProviderMetadata: providerMetadata,
            })
            return
        } else {
            yield(ai.StreamPart{
                Type: ai.StreamPartTypeError,
                Error: o.handleError(err),
            })
            return
        }
    }, nil
}

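// isReasoningModel reports whether the model is part of the reasoning family
// (o-series and gpt-5), which rejects sampling settings such as temperature
// and expects max_completion_tokens instead of max_tokens.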
func isReasoningModel(modelID string) bool {
    return strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5") || strings.HasPrefix(modelID, "gpt-5-chat")
}

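// isSearchPreviewModel reports whether the model is one of the search preview
// variants, which do not support temperature.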
func isSearchPreviewModel(modelID string) bool {
    return strings.Contains(modelID, "search-preview")
}

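// supportsFlexProcessing reports whether the model can use flex processing
// (the "flex" service tier).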
func supportsFlexProcessing(modelID string) bool {
    return strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini") || strings.HasPrefix(modelID, "gpt-5")
}

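// supportsPriorityProcessing reports whether the model can use priority
// processing (the "priority" service tier).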
func supportsPriorityProcessing(modelID string) bool {
    return strings.HasPrefix(modelID, "gpt-4") || strings.HasPrefix(modelID, "gpt-5") ||
        strings.HasPrefix(modelID, "gpt-5-mini") || strings.HasPrefix(modelID, "o3") ||
        strings.HasPrefix(modelID, "o4-mini")
}

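// toOpenAiTools converts the call's tools and tool choice into their OpenAI
// chat-completion equivalents, emitting warnings for unsupported tool types.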
func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
    for _, tool := range tools {
        if tool.GetType() == ai.ToolTypeFunction {
            ft, ok := tool.(ai.FunctionTool)
            if !ok {
                continue
            }
            openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
                OfFunction: &openai.ChatCompletionFunctionToolParam{
                    Function: shared.FunctionDefinitionParam{
                        Name: ft.Name,
                        Description: param.NewOpt(ft.Description),
                        Parameters: openai.FunctionParameters(ft.InputSchema),
                        Strict: param.NewOpt(false),
                    },
                    Type: "function",
                },
            })
            continue
        }

        // TODO: handle provider tool calls
        warnings = append(warnings, ai.CallWarning{
            Type: ai.CallWarningTypeUnsupportedTool,
            Tool: tool,
            Message: "tool is not supported",
        })
    }
    if toolChoice == nil {
        return openAiTools, openAiToolChoice, warnings
    }

    switch *toolChoice {
    case ai.ToolChoiceAuto:
        openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
            OfAuto: param.NewOpt("auto"),
        }
    case ai.ToolChoiceNone:
        openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
            OfAuto: param.NewOpt("none"),
        }
    default:
        openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
            OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
                Type: "function",
                Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
                    Name: string(*toolChoice),
                },
            },
        }
    }
    return openAiTools, openAiToolChoice, warnings
}

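// toPrompt converts an ai.Prompt into OpenAI chat messages, flattening system
// prompts into a single message and mapping text, file, tool-call, and
// tool-result parts; unsupported content is reported as warnings.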
func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
    var messages []openai.ChatCompletionMessageParamUnion
    var warnings []ai.CallWarning
    for _, msg := range prompt {
        switch msg.Role {
        case ai.MessageRoleSystem:
            var systemPromptParts []string
            for _, c := range msg.Content {
                if c.GetType() != ai.ContentTypeText {
                    warnings = append(warnings, ai.CallWarning{
                        Type: ai.CallWarningTypeOther,
                        Message: "system prompt can only have text content",
                    })
                    continue
                }
                textPart, ok := ai.AsContentType[ai.TextPart](c)
                if !ok {
                    warnings = append(warnings, ai.CallWarning{
                        Type: ai.CallWarningTypeOther,
                        Message: "system prompt text part does not have the right type",
                    })
                    continue
                }
                text := textPart.Text
                if strings.TrimSpace(text) != "" {
                    systemPromptParts = append(systemPromptParts, textPart.Text)
                }
            }
            if len(systemPromptParts) == 0 {
                warnings = append(warnings, ai.CallWarning{
                    Type: ai.CallWarningTypeOther,
                    Message: "system prompt has no text parts",
                })
                continue
            }
            messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
        case ai.MessageRoleUser:
            // Simple user message: just text content.
            if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
                textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
                if !ok {
                    warnings = append(warnings, ai.CallWarning{
                        Type: ai.CallWarningTypeOther,
                        Message: "user message text part does not have the right type",
                    })
                    continue
                }
                messages = append(messages, openai.UserMessage(textPart.Text))
                continue
            }
            // Text content plus attachments.
            // TODO: add the supported media types to the language model so we
            // can use them to validate the file data here.
            var content []openai.ChatCompletionContentPartUnionParam
            for _, c := range msg.Content {
                switch c.GetType() {
                case ai.ContentTypeText:
                    textPart, ok := ai.AsContentType[ai.TextPart](c)
                    if !ok {
                        warnings = append(warnings, ai.CallWarning{
                            Type: ai.CallWarningTypeOther,
                            Message: "user message text part does not have the right type",
                        })
                        continue
                    }
                    content = append(content, openai.ChatCompletionContentPartUnionParam{
                        OfText: &openai.ChatCompletionContentPartTextParam{
                            Text: textPart.Text,
                        },
                    })
                case ai.ContentTypeFile:
                    filePart, ok := ai.AsContentType[ai.FilePart](c)
                    if !ok {
                        warnings = append(warnings, ai.CallWarning{
                            Type: ai.CallWarningTypeOther,
                            Message: "user message file part does not have the right type",
                        })
                        continue
                    }

                    switch {
                    case strings.HasPrefix(filePart.MediaType, "image/"):
                        // Handle image files.
                        base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
                        data := "data:" + filePart.MediaType + ";base64," + base64Encoded
                        imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}

                        // Check for provider-specific options like image detail.
                        if providerOptions, ok := filePart.ProviderOptions[Name]; ok {
                            if detail, ok := providerOptions.(*ProviderFileOptions); ok {
                                imageURL.Detail = detail.ImageDetail
                            }
                        }

                        imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
                        content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})

                    case filePart.MediaType == "audio/wav":
                        // Handle WAV audio files.
                        base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
                        audioBlock := openai.ChatCompletionContentPartInputAudioParam{
                            InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
                                Data: base64Encoded,
                                Format: "wav",
                            },
                        }
                        content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

                    case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
                        // Handle MP3 audio files.
                        base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
                        audioBlock := openai.ChatCompletionContentPartInputAudioParam{
                            InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
                                Data: base64Encoded,
                                Format: "mp3",
                            },
                        }
                        content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

                    case filePart.MediaType == "application/pdf":
                        // Handle PDF files.
                        dataStr := string(filePart.Data)

                        // Check if the data looks like a file ID (starts with "file-").
                        if strings.HasPrefix(dataStr, "file-") {
                            fileBlock := openai.ChatCompletionContentPartFileParam{
                                File: openai.ChatCompletionContentPartFileFileParam{
                                    FileID: param.NewOpt(dataStr),
                                },
                            }
                            content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
                        } else {
                            // Handle as base64 data.
                            base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
                            data := "data:application/pdf;base64," + base64Encoded

                            filename := filePart.Filename
                            if filename == "" {
                                // Generate a default filename based on the content index.
                                filename = fmt.Sprintf("part-%d.pdf", len(content))
                            }

                            fileBlock := openai.ChatCompletionContentPartFileParam{
                                File: openai.ChatCompletionContentPartFileFileParam{
                                    Filename: param.NewOpt(filename),
                                    FileData: param.NewOpt(data),
                                },
                            }
                            content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
                        }

                    default:
                        warnings = append(warnings, ai.CallWarning{
                            Type: ai.CallWarningTypeOther,
                            Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
                        })
                    }
                }
            }
            messages = append(messages, openai.UserMessage(content))
        case ai.MessageRoleAssistant:
            // Simple assistant message: just text content.
            if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
                textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
                if !ok {
                    warnings = append(warnings, ai.CallWarning{
                        Type: ai.CallWarningTypeOther,
                        Message: "assistant message text part does not have the right type",
                    })
                    continue
                }
                messages = append(messages, openai.AssistantMessage(textPart.Text))
                continue
            }
            assistantMsg := openai.ChatCompletionAssistantMessageParam{
                Role: "assistant",
            }
            for _, c := range msg.Content {
                switch c.GetType() {
                case ai.ContentTypeText:
                    textPart, ok := ai.AsContentType[ai.TextPart](c)
                    if !ok {
                        warnings = append(warnings, ai.CallWarning{
                            Type: ai.CallWarningTypeOther,
                            Message: "assistant message text part does not have the right type",
                        })
                        continue
                    }
                    assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
                        OfString: param.NewOpt(textPart.Text),
                    }
                case ai.ContentTypeToolCall:
                    toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
                    if !ok {
                        warnings = append(warnings, ai.CallWarning{
                            Type: ai.CallWarningTypeOther,
                            Message: "assistant message tool part does not have the right type",
                        })
                        continue
                    }
                    assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
                        openai.ChatCompletionMessageToolCallUnionParam{
                            OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
                                ID: toolCallPart.ToolCallID,
                                Type: "function",
                                Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
                                    Name: toolCallPart.ToolName,
                                    Arguments: toolCallPart.Input,
                                },
                            },
                        })
                }
            }
            messages = append(messages, openai.ChatCompletionMessageParamUnion{
                OfAssistant: &assistantMsg,
            })
        case ai.MessageRoleTool:
            for _, c := range msg.Content {
                if c.GetType() != ai.ContentTypeToolResult {
                    warnings = append(warnings, ai.CallWarning{
                        Type: ai.CallWarningTypeOther,
                        Message: "tool message can only have tool result content",
                    })
                    continue
                }

                toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
                if !ok {
                    warnings = append(warnings, ai.CallWarning{
                        Type: ai.CallWarningTypeOther,
                        Message: "tool message result part does not have the right type",
                    })
                    continue
                }

                switch toolResultPart.Output.GetType() {
                case ai.ToolResultContentTypeText:
                    output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
                    if !ok {
                        warnings = append(warnings, ai.CallWarning{
                            Type: ai.CallWarningTypeOther,
                            Message: "tool result output does not have the right type",
                        })
                        continue
                    }
                    messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
                case ai.ToolResultContentTypeError:
                    // TODO: check if better handling is needed
                    output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
                    if !ok {
                        warnings = append(warnings, ai.CallWarning{
                            Type: ai.CallWarningTypeOther,
                            Message: "tool result output does not have the right type",
                        })
                        continue
                    }
                    messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
                }
            }
        }
    }
    return messages, warnings
}

// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
    var annotations []openai.ChatCompletionMessageAnnotation

    // Parse the raw JSON to extract annotations.
    var deltaData map[string]any
    if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
        return annotations
    }

    // Check if annotations exist in the delta.
    if annotationsData, ok := deltaData["annotations"].([]any); ok {
        for _, annotationData := range annotationsData {
            if annotationMap, ok := annotationData.(map[string]any); ok {
                if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
                    if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
                        // Use checked assertions so malformed citation fields
                        // cannot panic; missing values fall back to empty strings.
                        url, _ := urlCitationData["url"].(string)
                        title, _ := urlCitationData["title"].(string)
                        annotations = append(annotations, openai.ChatCompletionMessageAnnotation{
                            Type: "url_citation",
                            URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
                                URL: url,
                                Title: title,
                            },
                        })
                    }
                }
            }
        }
    }

    return annotations
}