1package providers
2
3import (
4 "context"
5 "encoding/base64"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "maps"
11 "strings"
12
13 "github.com/charmbracelet/crush/internal/ai"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/env"
16 "github.com/google/uuid"
17 "github.com/openai/openai-go/v2"
18 "github.com/openai/openai-go/v2/option"
19 "github.com/openai/openai-go/v2/packages/param"
20 "github.com/openai/openai-go/v2/shared"
21)
22
// ReasoningEffort selects how much internal reasoning an OpenAI
// reasoning-capable model spends before producing an answer.
type ReasoningEffort string

// Reasoning effort levels accepted by prepareParams and mapped onto the
// corresponding shared.ReasoningEffort* request values.
const (
	ReasoningEffortMinimal ReasoningEffort = "minimal"
	ReasoningEffortLow     ReasoningEffort = "low"
	ReasoningEffortMedium  ReasoningEffort = "medium"
	ReasoningEffortHigh    ReasoningEffort = "high"
)
31
32type OpenAIProviderOptions struct {
33 LogitBias map[string]int64 `json:"logit_bias"`
34 LogProbs *bool `json:"log_probes"`
35 TopLogProbs *int64 `json:"top_log_probs"`
36 ParallelToolCalls *bool `json:"parallel_tool_calls"`
37 User *string `json:"user"`
38 ReasoningEffort *ReasoningEffort `json:"reasoning_effort"`
39 MaxCompletionTokens *int64 `json:"max_completion_tokens"`
40 TextVerbosity *string `json:"text_verbosity"`
41 Prediction map[string]any `json:"prediction"`
42 Store *bool `json:"store"`
43 Metadata map[string]any `json:"metadata"`
44 PromptCacheKey *string `json:"prompt_cache_key"`
45 SafetyIdentifier *string `json:"safety_identifier"`
46 ServiceTier *string `json:"service_tier"`
47 StructuredOutputs *bool `json:"structured_outputs"`
48}
49
// openAIProvider is the concrete ai.Provider backed by the OpenAI
// chat-completions API.
type openAIProvider struct {
	options openAIProviderOptions
}

// openAIProviderOptions holds provider construction settings, populated
// by the WithOpenAI* functional options and resolved in NewOpenAIProvider.
type openAIProviderOptions struct {
	baseURL      string            // API base URL; defaults to https://api.openai.com/v1
	apiKey       string            // may contain a variable reference, resolved at construction
	organization string            // sent as the OpenAI-Organization header when non-empty
	project      string            // sent as the OpenAI-Project header when non-empty
	name         string            // provider name used in the model's Provider() string; defaults to "openai"
	headers      map[string]string // extra request headers; values are variable-resolved
	client       option.HTTPClient // optional custom HTTP client
	resolver     config.VariableResolver // resolves env/shell references in the string options above
}
64
// OpenAIOption mutates provider construction settings; pass to NewOpenAIProvider.
type OpenAIOption = func(*openAIProviderOptions)
66
67func NewOpenAIProvider(opts ...OpenAIOption) ai.Provider {
68 options := openAIProviderOptions{
69 headers: map[string]string{},
70 }
71 for _, o := range opts {
72 o(&options)
73 }
74
75 if options.resolver == nil {
76 // use the default resolver
77 options.resolver = config.NewShellVariableResolver(env.New())
78 }
79 options.apiKey, _ = options.resolver.ResolveValue(options.apiKey)
80 options.baseURL, _ = options.resolver.ResolveValue(options.baseURL)
81 if options.baseURL == "" {
82 options.baseURL = "https://api.openai.com/v1"
83 }
84
85 options.name, _ = options.resolver.ResolveValue(options.name)
86 if options.name == "" {
87 options.name = "openai"
88 }
89
90 for k, v := range options.headers {
91 options.headers[k], _ = options.resolver.ResolveValue(v)
92 }
93
94 options.organization, _ = options.resolver.ResolveValue(options.organization)
95 if options.organization != "" {
96 options.headers["OpenAI-Organization"] = options.organization
97 }
98
99 options.project, _ = options.resolver.ResolveValue(options.project)
100 if options.project != "" {
101 options.headers["OpenAI-Project"] = options.project
102 }
103
104 return &openAIProvider{
105 options: options,
106 }
107}
108
109func WithOpenAIBaseURL(baseURL string) OpenAIOption {
110 return func(o *openAIProviderOptions) {
111 o.baseURL = baseURL
112 }
113}
114
115func WithOpenAIApiKey(apiKey string) OpenAIOption {
116 return func(o *openAIProviderOptions) {
117 o.apiKey = apiKey
118 }
119}
120
121func WithOpenAIOrganization(organization string) OpenAIOption {
122 return func(o *openAIProviderOptions) {
123 o.organization = organization
124 }
125}
126
127func WithOpenAIProject(project string) OpenAIOption {
128 return func(o *openAIProviderOptions) {
129 o.project = project
130 }
131}
132
133func WithOpenAIName(name string) OpenAIOption {
134 return func(o *openAIProviderOptions) {
135 o.name = name
136 }
137}
138
139func WithOpenAIHeaders(headers map[string]string) OpenAIOption {
140 return func(o *openAIProviderOptions) {
141 maps.Copy(o.headers, headers)
142 }
143}
144
145func WithOpenAIHttpClient(client option.HTTPClient) OpenAIOption {
146 return func(o *openAIProviderOptions) {
147 o.client = client
148 }
149}
150
151func WithOpenAIVariableResolver(resolver config.VariableResolver) OpenAIOption {
152 return func(o *openAIProviderOptions) {
153 o.resolver = resolver
154 }
155}
156
157// LanguageModel implements ai.Provider.
158func (o *openAIProvider) LanguageModel(modelID string) ai.LanguageModel {
159 openaiClientOptions := []option.RequestOption{}
160 if o.options.apiKey != "" {
161 openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
162 }
163 if o.options.baseURL != "" {
164 openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
165 }
166
167 for key, value := range o.options.headers {
168 openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
169 }
170
171 if o.options.client != nil {
172 openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
173 }
174
175 return openAILanguageModel{
176 modelID: modelID,
177 provider: fmt.Sprintf("%s.chat", o.options.name),
178 providerOptions: o.options,
179 client: openai.NewClient(openaiClientOptions...),
180 }
181}
182
// openAILanguageModel is an ai.LanguageModel bound to a single OpenAI
// chat-completions model.
type openAILanguageModel struct {
	provider        string // "<providerName>.chat", set in LanguageModel
	modelID         string
	client          openai.Client
	providerOptions openAIProviderOptions
}
189
// Model implements ai.LanguageModel. It returns the model ID this
// instance was created for.
func (o openAILanguageModel) Model() string {
	return o.modelID
}
194
// Provider implements ai.LanguageModel. It returns the provider label
// ("<name>.chat") assigned at construction.
func (o openAILanguageModel) Provider() string {
	return o.provider
}
199
200func (o openAILanguageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
201 params := &openai.ChatCompletionNewParams{}
202 messages, warnings := toOpenAIPrompt(call.Prompt)
203 providerOptions := &OpenAIProviderOptions{}
204 if v, ok := call.ProviderOptions["openai"]; ok {
205 err := ai.ParseOptions(v, providerOptions)
206 if err != nil {
207 return nil, nil, err
208 }
209 }
210 if call.TopK != nil {
211 warnings = append(warnings, ai.CallWarning{
212 Type: ai.CallWarningTypeUnsupportedSetting,
213 Setting: "top_k",
214 })
215 }
216 params.Messages = messages
217 params.Model = o.modelID
218 if providerOptions.LogitBias != nil {
219 params.LogitBias = providerOptions.LogitBias
220 }
221 if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
222 providerOptions.LogProbs = nil
223 }
224 if providerOptions.LogProbs != nil {
225 params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
226 }
227 if providerOptions.TopLogProbs != nil {
228 params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
229 }
230 if providerOptions.User != nil {
231 params.User = param.NewOpt(*providerOptions.User)
232 }
233 if providerOptions.ParallelToolCalls != nil {
234 params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
235 }
236
237 if call.MaxOutputTokens != nil {
238 params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
239 }
240 if call.Temperature != nil {
241 params.Temperature = param.NewOpt(*call.Temperature)
242 }
243 if call.TopP != nil {
244 params.TopP = param.NewOpt(*call.TopP)
245 }
246 if call.FrequencyPenalty != nil {
247 params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
248 }
249 if call.PresencePenalty != nil {
250 params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
251 }
252
253 if providerOptions.MaxCompletionTokens != nil {
254 params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
255 }
256
257 if providerOptions.TextVerbosity != nil {
258 params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
259 }
260 if providerOptions.Prediction != nil {
261 // Convert map[string]any to ChatCompletionPredictionContentParam
262 if content, ok := providerOptions.Prediction["content"]; ok {
263 if contentStr, ok := content.(string); ok {
264 params.Prediction = openai.ChatCompletionPredictionContentParam{
265 Content: openai.ChatCompletionPredictionContentContentUnionParam{
266 OfString: param.NewOpt(contentStr),
267 },
268 }
269 }
270 }
271 }
272 if providerOptions.Store != nil {
273 params.Store = param.NewOpt(*providerOptions.Store)
274 }
275 if providerOptions.Metadata != nil {
276 // Convert map[string]any to map[string]string
277 metadata := make(map[string]string)
278 for k, v := range providerOptions.Metadata {
279 if str, ok := v.(string); ok {
280 metadata[k] = str
281 }
282 }
283 params.Metadata = metadata
284 }
285 if providerOptions.PromptCacheKey != nil {
286 params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
287 }
288 if providerOptions.SafetyIdentifier != nil {
289 params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
290 }
291 if providerOptions.ServiceTier != nil {
292 params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
293 }
294
295 if providerOptions.ReasoningEffort != nil {
296 switch *providerOptions.ReasoningEffort {
297 case ReasoningEffortMinimal:
298 params.ReasoningEffort = shared.ReasoningEffortMinimal
299 case ReasoningEffortLow:
300 params.ReasoningEffort = shared.ReasoningEffortLow
301 case ReasoningEffortMedium:
302 params.ReasoningEffort = shared.ReasoningEffortMedium
303 case ReasoningEffortHigh:
304 params.ReasoningEffort = shared.ReasoningEffortHigh
305 default:
306 return nil, nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
307 }
308 }
309
310 if isReasoningModel(o.modelID) {
311 // remove unsupported settings for reasoning models
312 // see https://platform.openai.com/docs/guides/reasoning#limitations
313 if call.Temperature != nil {
314 params.Temperature = param.Opt[float64]{}
315 warnings = append(warnings, ai.CallWarning{
316 Type: ai.CallWarningTypeUnsupportedSetting,
317 Setting: "temperature",
318 Details: "temperature is not supported for reasoning models",
319 })
320 }
321 if call.TopP != nil {
322 params.TopP = param.Opt[float64]{}
323 warnings = append(warnings, ai.CallWarning{
324 Type: ai.CallWarningTypeUnsupportedSetting,
325 Setting: "top_p",
326 Details: "topP is not supported for reasoning models",
327 })
328 }
329 if call.FrequencyPenalty != nil {
330 params.FrequencyPenalty = param.Opt[float64]{}
331 warnings = append(warnings, ai.CallWarning{
332 Type: ai.CallWarningTypeUnsupportedSetting,
333 Setting: "frequency_penalty",
334 Details: "frequencyPenalty is not supported for reasoning models",
335 })
336 }
337 if call.PresencePenalty != nil {
338 params.PresencePenalty = param.Opt[float64]{}
339 warnings = append(warnings, ai.CallWarning{
340 Type: ai.CallWarningTypeUnsupportedSetting,
341 Setting: "presence_penalty",
342 Details: "presencePenalty is not supported for reasoning models",
343 })
344 }
345 if providerOptions.LogitBias != nil {
346 params.LogitBias = nil
347 warnings = append(warnings, ai.CallWarning{
348 Type: ai.CallWarningTypeOther,
349 Message: "logitBias is not supported for reasoning models",
350 })
351 }
352 if providerOptions.LogProbs != nil {
353 params.Logprobs = param.Opt[bool]{}
354 warnings = append(warnings, ai.CallWarning{
355 Type: ai.CallWarningTypeOther,
356 Message: "logprobs is not supported for reasoning models",
357 })
358 }
359 if providerOptions.TopLogProbs != nil {
360 params.TopLogprobs = param.Opt[int64]{}
361 warnings = append(warnings, ai.CallWarning{
362 Type: ai.CallWarningTypeOther,
363 Message: "topLogprobs is not supported for reasoning models",
364 })
365 }
366
367 // reasoning models use max_completion_tokens instead of max_tokens
368 if call.MaxOutputTokens != nil {
369 if providerOptions.MaxCompletionTokens == nil {
370 params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
371 }
372 params.MaxTokens = param.Opt[int64]{}
373 }
374 }
375
376 // Handle search preview models
377 if isSearchPreviewModel(o.modelID) {
378 if call.Temperature != nil {
379 params.Temperature = param.Opt[float64]{}
380 warnings = append(warnings, ai.CallWarning{
381 Type: ai.CallWarningTypeUnsupportedSetting,
382 Setting: "temperature",
383 Details: "temperature is not supported for the search preview models and has been removed.",
384 })
385 }
386 }
387
388 // Handle service tier validation
389 if providerOptions.ServiceTier != nil {
390 serviceTier := *providerOptions.ServiceTier
391 if serviceTier == "flex" && !supportsFlexProcessing(o.modelID) {
392 params.ServiceTier = ""
393 warnings = append(warnings, ai.CallWarning{
394 Type: ai.CallWarningTypeUnsupportedSetting,
395 Setting: "serviceTier",
396 Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
397 })
398 } else if serviceTier == "priority" && !supportsPriorityProcessing(o.modelID) {
399 params.ServiceTier = ""
400 warnings = append(warnings, ai.CallWarning{
401 Type: ai.CallWarningTypeUnsupportedSetting,
402 Setting: "serviceTier",
403 Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported",
404 })
405 }
406 }
407
408 if len(call.Tools) > 0 {
409 tools, toolChoice, toolWarnings := toOpenAITools(call.Tools, call.ToolChoice)
410 params.Tools = tools
411 if toolChoice != nil {
412 params.ToolChoice = *toolChoice
413 }
414 warnings = append(warnings, toolWarnings...)
415 }
416 return params, warnings, nil
417}
418
// Generate implements ai.LanguageModel. It issues a single (non-streaming)
// chat completion and maps the first choice into generic ai content:
// text, tool calls, and url_citation annotations as sources. Returns an
// error if the API call fails or no choices are returned.
func (o openAILanguageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, err
	}

	if len(response.Choices) == 0 {
		return nil, errors.New("no response generated")
	}
	// Only the first choice is consumed; n>1 is not requested anywhere here.
	choice := response.Choices[0]
	var content []ai.Content
	text := choice.Message.Content
	if text != "" {
		content = append(content, ai.TextContent{
			Text: text,
		})
	}

	for _, tc := range choice.Message.ToolCalls {
		// Some backends omit tool-call IDs; synthesize one so downstream
		// correlation still works.
		toolCallID := tc.ID
		if toolCallID == "" {
			toolCallID = uuid.NewString()
		}
		content = append(content, ai.ToolCallContent{
			ProviderExecuted: false, // TODO: update when handling other tools
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Handle annotations/citations
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, ai.SourceContent{
				SourceType: ai.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	completionTokenDetails := response.Usage.CompletionTokensDetails
	promptTokenDetails := response.Usage.PromptTokensDetails

	// Build provider metadata
	providerMetadata := ai.ProviderMetadata{
		"openai": make(map[string]any),
	}

	// Add logprobs if available
	if len(choice.Logprobs.Content) > 0 {
		providerMetadata["openai"]["logprobs"] = choice.Logprobs.Content
	}

	// Add prediction tokens if available
	if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
		if completionTokenDetails.AcceptedPredictionTokens > 0 {
			providerMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
		}
		if completionTokenDetails.RejectedPredictionTokens > 0 {
			providerMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
		}
	}

	return &ai.Response{
		Content: content,
		Usage: ai.Usage{
			InputTokens:     response.Usage.PromptTokens,
			OutputTokens:    response.Usage.CompletionTokens,
			TotalTokens:     response.Usage.TotalTokens,
			ReasoningTokens: completionTokenDetails.ReasoningTokens,
			CacheReadTokens: promptTokenDetails.CachedTokens,
		},
		FinishReason:     mapOpenAIFinishReason(choice.FinishReason),
		ProviderMetadata: providerMetadata,
		Warnings:         warnings,
	}, nil
}
503
// toolCall tracks one streaming tool call while its JSON arguments are
// accumulated across deltas (keyed by the choice delta index in Stream).
type toolCall struct {
	id          string
	name        string
	arguments   string // concatenated argument fragments; emitted once parsable as JSON
	hasFinished bool   // set after the final tool-call part has been yielded
}
510
511// Stream implements ai.LanguageModel.
512func (o openAILanguageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
513 params, warnings, err := o.prepareParams(call)
514 if err != nil {
515 return nil, err
516 }
517
518 params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
519 IncludeUsage: openai.Bool(true),
520 }
521
522 stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
523 isActiveText := false
524 toolCalls := make(map[int64]toolCall)
525
526 // Build provider metadata for streaming
527 streamProviderMetadata := ai.ProviderOptions{
528 "openai": make(map[string]any),
529 }
530
531 acc := openai.ChatCompletionAccumulator{}
532 var usage ai.Usage
533 return func(yield func(ai.StreamPart) bool) {
534 if len(warnings) > 0 {
535 if !yield(ai.StreamPart{
536 Type: ai.StreamPartTypeWarnings,
537 Warnings: warnings,
538 }) {
539 return
540 }
541 }
542 for stream.Next() {
543 chunk := stream.Current()
544 acc.AddChunk(chunk)
545 if chunk.Usage.TotalTokens > 0 {
546 // we do this here because the acc does not add prompt details
547 completionTokenDetails := chunk.Usage.CompletionTokensDetails
548 promptTokenDetails := chunk.Usage.PromptTokensDetails
549 usage = ai.Usage{
550 InputTokens: chunk.Usage.PromptTokens,
551 OutputTokens: chunk.Usage.CompletionTokens,
552 TotalTokens: chunk.Usage.TotalTokens,
553 ReasoningTokens: completionTokenDetails.ReasoningTokens,
554 CacheReadTokens: promptTokenDetails.CachedTokens,
555 }
556
557 // Add prediction tokens if available
558 if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
559 if completionTokenDetails.AcceptedPredictionTokens > 0 {
560 streamProviderMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
561 }
562 if completionTokenDetails.RejectedPredictionTokens > 0 {
563 streamProviderMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
564 }
565 }
566 }
567 if len(chunk.Choices) == 0 {
568 continue
569 }
570 for _, choice := range chunk.Choices {
571 switch {
572 case choice.Delta.Content != "":
573 if !isActiveText {
574 isActiveText = true
575 if !yield(ai.StreamPart{
576 Type: ai.StreamPartTypeTextStart,
577 ID: "0",
578 }) {
579 return
580 }
581 }
582 if !yield(ai.StreamPart{
583 Type: ai.StreamPartTypeTextDelta,
584 ID: "0",
585 Delta: choice.Delta.Content,
586 }) {
587 return
588 }
589 case len(choice.Delta.ToolCalls) > 0:
590 if isActiveText {
591 isActiveText = false
592 if !yield(ai.StreamPart{
593 Type: ai.StreamPartTypeTextEnd,
594 ID: "0",
595 }) {
596 return
597 }
598 }
599
600 for _, toolCallDelta := range choice.Delta.ToolCalls {
601 if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
602 if existingToolCall.hasFinished {
603 continue
604 }
605 if toolCallDelta.Function.Arguments != "" {
606 existingToolCall.arguments += toolCallDelta.Function.Arguments
607 }
608 if !yield(ai.StreamPart{
609 Type: ai.StreamPartTypeToolInputDelta,
610 ID: existingToolCall.id,
611 Delta: toolCallDelta.Function.Arguments,
612 }) {
613 return
614 }
615 toolCalls[toolCallDelta.Index] = existingToolCall
616 if existingToolCall.arguments != "" && ai.IsParsableJSON(existingToolCall.arguments) {
617 if !yield(ai.StreamPart{
618 Type: ai.StreamPartTypeToolInputEnd,
619 ID: existingToolCall.id,
620 }) {
621 return
622 }
623
624 if !yield(ai.StreamPart{
625 Type: ai.StreamPartTypeToolCall,
626 ID: existingToolCall.id,
627 ToolCallName: existingToolCall.name,
628 ToolCallInput: existingToolCall.arguments,
629 }) {
630 return
631 }
632 existingToolCall.hasFinished = true
633 toolCalls[toolCallDelta.Index] = existingToolCall
634 }
635
636 } else {
637 // Does not exist
638 var err error
639 if toolCallDelta.Type != "function" {
640 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
641 }
642 if toolCallDelta.ID == "" {
643 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
644 }
645 if toolCallDelta.Function.Name == "" {
646 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
647 }
648 if err != nil {
649 yield(ai.StreamPart{
650 Type: ai.StreamPartTypeError,
651 Error: stream.Err(),
652 })
653 return
654 }
655
656 if !yield(ai.StreamPart{
657 Type: ai.StreamPartTypeToolInputStart,
658 ID: toolCallDelta.ID,
659 ToolCallName: toolCallDelta.Function.Name,
660 }) {
661 return
662 }
663 toolCalls[toolCallDelta.Index] = toolCall{
664 id: toolCallDelta.ID,
665 name: toolCallDelta.Function.Name,
666 arguments: toolCallDelta.Function.Arguments,
667 }
668
669 exTc := toolCalls[toolCallDelta.Index]
670 if exTc.arguments != "" {
671 if !yield(ai.StreamPart{
672 Type: ai.StreamPartTypeToolInputDelta,
673 ID: exTc.id,
674 Delta: exTc.arguments,
675 }) {
676 return
677 }
678 if ai.IsParsableJSON(toolCalls[toolCallDelta.Index].arguments) {
679 if !yield(ai.StreamPart{
680 Type: ai.StreamPartTypeToolInputEnd,
681 ID: toolCallDelta.ID,
682 }) {
683 return
684 }
685
686 if !yield(ai.StreamPart{
687 Type: ai.StreamPartTypeToolCall,
688 ID: exTc.id,
689 ToolCallName: exTc.name,
690 ToolCallInput: exTc.arguments,
691 }) {
692 return
693 }
694 exTc.hasFinished = true
695 toolCalls[toolCallDelta.Index] = exTc
696 }
697 }
698 continue
699 }
700 }
701 }
702 }
703
704 // Check for annotations in the delta's raw JSON
705 for _, choice := range chunk.Choices {
706 if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
707 for _, annotation := range annotations {
708 if annotation.Type == "url_citation" {
709 if !yield(ai.StreamPart{
710 Type: ai.StreamPartTypeSource,
711 ID: uuid.NewString(),
712 SourceType: ai.SourceTypeURL,
713 URL: annotation.URLCitation.URL,
714 Title: annotation.URLCitation.Title,
715 }) {
716 return
717 }
718 }
719 }
720 }
721 }
722
723 }
724 err := stream.Err()
725 if err == nil || errors.Is(err, io.EOF) {
726 // finished
727 if isActiveText {
728 isActiveText = false
729 if !yield(ai.StreamPart{
730 Type: ai.StreamPartTypeTextEnd,
731 ID: "0",
732 }) {
733 return
734 }
735 }
736
737 // Add logprobs if available
738 if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
739 streamProviderMetadata["openai"]["logprobs"] = acc.Choices[0].Logprobs.Content
740 }
741
742 // Handle annotations/citations from accumulated response
743 if len(acc.Choices) > 0 {
744 for _, annotation := range acc.Choices[0].Message.Annotations {
745 if annotation.Type == "url_citation" {
746 if !yield(ai.StreamPart{
747 Type: ai.StreamPartTypeSource,
748 ID: uuid.NewString(),
749 SourceType: ai.SourceTypeURL,
750 URL: annotation.URLCitation.URL,
751 Title: annotation.URLCitation.Title,
752 }) {
753 return
754 }
755 }
756 }
757 }
758
759 finishReason := mapOpenAIFinishReason(acc.Choices[0].FinishReason)
760 yield(ai.StreamPart{
761 Type: ai.StreamPartTypeFinish,
762 Usage: usage,
763 FinishReason: finishReason,
764 ProviderMetadata: streamProviderMetadata,
765 })
766 return
767
768 } else {
769 yield(ai.StreamPart{
770 Type: ai.StreamPartTypeError,
771 Error: stream.Err(),
772 })
773 return
774 }
775 }, nil
776}
777
// mapOpenAIFinishReason converts OpenAI's finish_reason string into the
// generic ai.FinishReason. Unrecognized values map to FinishReasonUnknown.
func mapOpenAIFinishReason(finishReason string) ai.FinishReason {
	switch finishReason {
	case "stop":
		return ai.FinishReasonStop
	case "length":
		return ai.FinishReasonLength
	case "content_filter":
		return ai.FinishReasonContentFilter
	case "function_call", "tool_calls":
		return ai.FinishReasonToolCalls
	default:
		return ai.FinishReasonUnknown
	}
}
792
// isReasoningModel reports whether the model uses OpenAI's reasoning API
// surface (o-series and the gpt-5 family). gpt-5-chat models are
// standard chat models and are excluded — the original
// `HasPrefix("gpt-5-chat")` clause was unreachable (already covered by
// the "gpt-5" prefix) and would have wrongly stripped sampling settings
// from the non-reasoning chat variant.
func isReasoningModel(modelID string) bool {
	if strings.HasPrefix(modelID, "gpt-5-chat") {
		return false
	}
	return strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5")
}
796
// isSearchPreviewModel reports whether the model is one of OpenAI's
// "search preview" chat models, which reject some sampling parameters.
func isSearchPreviewModel(modelID string) bool {
	const marker = "search-preview"
	return strings.Contains(modelID, marker)
}
800
// supportsFlexProcessing reports whether the model may use the "flex"
// service tier (o3, o4-mini, and the gpt-5 family).
func supportsFlexProcessing(modelID string) bool {
	for _, prefix := range []string{"o3", "o4-mini", "gpt-5"} {
		if strings.HasPrefix(modelID, prefix) {
			return true
		}
	}
	return false
}
804
// supportsPriorityProcessing reports whether the model may use the
// "priority" service tier (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini).
// gpt-5-nano is explicitly excluded — the warning emitted in
// prepareParams documents this, but the previous prefix check
// ("gpt-5") incorrectly matched it. The redundant gpt-5-mini clause
// (already covered by "gpt-5") is dropped.
func supportsPriorityProcessing(modelID string) bool {
	if strings.HasPrefix(modelID, "gpt-5-nano") {
		return false
	}
	return strings.HasPrefix(modelID, "gpt-4") || strings.HasPrefix(modelID, "gpt-5") ||
		strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini")
}
810
811func toOpenAITools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAITools []openai.ChatCompletionToolUnionParam, openAIToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
812 for _, tool := range tools {
813 if tool.GetType() == ai.ToolTypeFunction {
814 ft, ok := tool.(ai.FunctionTool)
815 if !ok {
816 continue
817 }
818 openAITools = append(openAITools, openai.ChatCompletionToolUnionParam{
819 OfFunction: &openai.ChatCompletionFunctionToolParam{
820 Function: shared.FunctionDefinitionParam{
821 Name: ft.Name,
822 Description: param.NewOpt(ft.Description),
823 Parameters: openai.FunctionParameters(ft.InputSchema),
824 Strict: param.NewOpt(false),
825 },
826 Type: "function",
827 },
828 })
829 continue
830 }
831
832 // TODO: handle provider tool calls
833 warnings = append(warnings, ai.CallWarning{
834 Type: ai.CallWarningTypeUnsupportedTool,
835 Tool: tool,
836 Message: "tool is not supported",
837 })
838 }
839 if toolChoice == nil {
840 return
841 }
842
843 switch *toolChoice {
844 case ai.ToolChoiceAuto:
845 openAIToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
846 OfAuto: param.NewOpt("auto"),
847 }
848 case ai.ToolChoiceNone:
849 openAIToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
850 OfAuto: param.NewOpt("none"),
851 }
852 default:
853 openAIToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
854 OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
855 Type: "function",
856 Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
857 Name: string(*toolChoice),
858 },
859 },
860 }
861 }
862 return
863}
864
865func toOpenAIPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
866 var messages []openai.ChatCompletionMessageParamUnion
867 var warnings []ai.CallWarning
868 for _, msg := range prompt {
869 switch msg.Role {
870 case ai.MessageRoleSystem:
871 var systemPromptParts []string
872 for _, c := range msg.Content {
873 if c.GetType() != ai.ContentTypeText {
874 warnings = append(warnings, ai.CallWarning{
875 Type: ai.CallWarningTypeOther,
876 Message: "system prompt can only have text content",
877 })
878 continue
879 }
880 textPart, ok := ai.AsContentType[ai.TextPart](c)
881 if !ok {
882 warnings = append(warnings, ai.CallWarning{
883 Type: ai.CallWarningTypeOther,
884 Message: "system prompt text part does not have the right type",
885 })
886 continue
887 }
888 text := textPart.Text
889 if strings.TrimSpace(text) != "" {
890 systemPromptParts = append(systemPromptParts, textPart.Text)
891 }
892 }
893 if len(systemPromptParts) == 0 {
894 warnings = append(warnings, ai.CallWarning{
895 Type: ai.CallWarningTypeOther,
896 Message: "system prompt has no text parts",
897 })
898 continue
899 }
900 messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
901 case ai.MessageRoleUser:
902 // simple user message just text content
903 if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
904 textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
905 if !ok {
906 warnings = append(warnings, ai.CallWarning{
907 Type: ai.CallWarningTypeOther,
908 Message: "user message text part does not have the right type",
909 })
910 continue
911 }
912 messages = append(messages, openai.UserMessage(textPart.Text))
913 continue
914 }
915 // text content and attachments
916 // for now we only support image content later we need to check
917 // TODO: add the supported media types to the language model so we
918 // can use that to validate the data here.
919 var content []openai.ChatCompletionContentPartUnionParam
920 for _, c := range msg.Content {
921 switch c.GetType() {
922 case ai.ContentTypeText:
923 textPart, ok := ai.AsContentType[ai.TextPart](c)
924 if !ok {
925 warnings = append(warnings, ai.CallWarning{
926 Type: ai.CallWarningTypeOther,
927 Message: "user message text part does not have the right type",
928 })
929 continue
930 }
931 content = append(content, openai.ChatCompletionContentPartUnionParam{
932 OfText: &openai.ChatCompletionContentPartTextParam{
933 Text: textPart.Text,
934 },
935 })
936 case ai.ContentTypeFile:
937 filePart, ok := ai.AsContentType[ai.FilePart](c)
938 if !ok {
939 warnings = append(warnings, ai.CallWarning{
940 Type: ai.CallWarningTypeOther,
941 Message: "user message file part does not have the right type",
942 })
943 continue
944 }
945
946 switch {
947 case strings.HasPrefix(filePart.MediaType, "image/"):
948 // Handle image files
949 base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
950 data := "data:" + filePart.MediaType + ";base64," + base64Encoded
951 imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}
952
953 // Check for provider-specific options like image detail
954 if providerOptions, ok := filePart.ProviderOptions["openai"]; ok {
955 if detail, ok := providerOptions["imageDetail"].(string); ok {
956 imageURL.Detail = detail
957 }
958 }
959
960 imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
961 content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
962
963 case filePart.MediaType == "audio/wav":
964 // Handle WAV audio files
965 base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
966 audioBlock := openai.ChatCompletionContentPartInputAudioParam{
967 InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
968 Data: base64Encoded,
969 Format: "wav",
970 },
971 }
972 content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
973
974 case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
975 // Handle MP3 audio files
976 base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
977 audioBlock := openai.ChatCompletionContentPartInputAudioParam{
978 InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
979 Data: base64Encoded,
980 Format: "mp3",
981 },
982 }
983 content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
984
985 case filePart.MediaType == "application/pdf":
986 // Handle PDF files
987 dataStr := string(filePart.Data)
988
989 // Check if data looks like a file ID (starts with "file-")
990 if strings.HasPrefix(dataStr, "file-") {
991 fileBlock := openai.ChatCompletionContentPartFileParam{
992 File: openai.ChatCompletionContentPartFileFileParam{
993 FileID: param.NewOpt(dataStr),
994 },
995 }
996 content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
997 } else {
998 // Handle as base64 data
999 base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
1000 data := "data:application/pdf;base64," + base64Encoded
1001
1002 filename := filePart.Filename
1003 if filename == "" {
1004 // Generate default filename based on content index
1005 filename = fmt.Sprintf("part-%d.pdf", len(content))
1006 }
1007
1008 fileBlock := openai.ChatCompletionContentPartFileParam{
1009 File: openai.ChatCompletionContentPartFileFileParam{
1010 Filename: param.NewOpt(filename),
1011 FileData: param.NewOpt(data),
1012 },
1013 }
1014 content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
1015 }
1016
1017 default:
1018 warnings = append(warnings, ai.CallWarning{
1019 Type: ai.CallWarningTypeOther,
1020 Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
1021 })
1022 }
1023 }
1024 }
1025 messages = append(messages, openai.UserMessage(content))
1026 case ai.MessageRoleAssistant:
1027 // simple assistant message just text content
1028 if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
1029 textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
1030 if !ok {
1031 warnings = append(warnings, ai.CallWarning{
1032 Type: ai.CallWarningTypeOther,
1033 Message: "assistant message text part does not have the right type",
1034 })
1035 continue
1036 }
1037 messages = append(messages, openai.AssistantMessage(textPart.Text))
1038 continue
1039 }
1040 assistantMsg := openai.ChatCompletionAssistantMessageParam{
1041 Role: "assistant",
1042 }
1043 for _, c := range msg.Content {
1044 switch c.GetType() {
1045 case ai.ContentTypeText:
1046 textPart, ok := ai.AsContentType[ai.TextPart](c)
1047 if !ok {
1048 warnings = append(warnings, ai.CallWarning{
1049 Type: ai.CallWarningTypeOther,
1050 Message: "assistant message text part does not have the right type",
1051 })
1052 continue
1053 }
1054 assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
1055 OfString: param.NewOpt(textPart.Text),
1056 }
1057 case ai.ContentTypeToolCall:
1058 toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
1059 if !ok {
1060 warnings = append(warnings, ai.CallWarning{
1061 Type: ai.CallWarningTypeOther,
1062 Message: "assistant message tool part does not have the right type",
1063 })
1064 continue
1065 }
1066 assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
1067 openai.ChatCompletionMessageToolCallUnionParam{
1068 OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
1069 ID: toolCallPart.ToolCallID,
1070 Type: "function",
1071 Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
1072 Name: toolCallPart.ToolName,
1073 Arguments: toolCallPart.Input,
1074 },
1075 },
1076 })
1077 }
1078 }
1079 messages = append(messages, openai.ChatCompletionMessageParamUnion{
1080 OfAssistant: &assistantMsg,
1081 })
1082 case ai.MessageRoleTool:
1083 for _, c := range msg.Content {
1084 if c.GetType() != ai.ContentTypeToolResult {
1085 warnings = append(warnings, ai.CallWarning{
1086 Type: ai.CallWarningTypeOther,
1087 Message: "tool message can only have tool result content",
1088 })
1089 continue
1090 }
1091
1092 toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
1093 if !ok {
1094 warnings = append(warnings, ai.CallWarning{
1095 Type: ai.CallWarningTypeOther,
1096 Message: "tool message result part does not have the right type",
1097 })
1098 continue
1099 }
1100
1101 switch toolResultPart.Output.GetType() {
1102 case ai.ToolResultContentTypeText:
1103 output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
1104 if !ok {
1105 warnings = append(warnings, ai.CallWarning{
1106 Type: ai.CallWarningTypeOther,
1107 Message: "tool result output does not have the right type",
1108 })
1109 continue
1110 }
1111 messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
1112 case ai.ToolResultContentTypeError:
1113 // TODO: check if better handling is needed
1114 output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
1115 if !ok {
1116 warnings = append(warnings, ai.CallWarning{
1117 Type: ai.CallWarningTypeOther,
1118 Message: "tool result output does not have the right type",
1119 })
1120 continue
1121 }
1122 messages = append(messages, openai.ToolMessage(output.Error, toolResultPart.ToolCallID))
1123 }
1124 }
1125 }
1126 }
1127 return messages, warnings
1128}
1129
1130// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta
1131func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
1132 var annotations []openai.ChatCompletionMessageAnnotation
1133
1134 // Parse the raw JSON to extract annotations
1135 var deltaData map[string]interface{}
1136 if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
1137 return annotations
1138 }
1139
1140 // Check if annotations exist in the delta
1141 if annotationsData, ok := deltaData["annotations"].([]interface{}); ok {
1142 for _, annotationData := range annotationsData {
1143 if annotationMap, ok := annotationData.(map[string]interface{}); ok {
1144 if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
1145 if urlCitationData, ok := annotationMap["url_citation"].(map[string]interface{}); ok {
1146 annotation := openai.ChatCompletionMessageAnnotation{
1147 Type: "url_citation",
1148 URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
1149 URL: urlCitationData["url"].(string),
1150 Title: urlCitationData["title"].(string),
1151 },
1152 }
1153 annotations = append(annotations, annotation)
1154 }
1155 }
1156 }
1157 }
1158 }
1159
1160 return annotations
1161}