openai.go

package providers

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"maps"
	"strings"

	"github.com/charmbracelet/crush/internal/ai"
	"github.com/google/uuid"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/option"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/openai/openai-go/v2/shared"
)

type ReasoningEffort string

const (
	ReasoningEffortMinimal ReasoningEffort = "minimal"
	ReasoningEffortLow     ReasoningEffort = "low"
	ReasoningEffortMedium  ReasoningEffort = "medium"
	ReasoningEffortHigh    ReasoningEffort = "high"
)

type OpenAiProviderOptions struct {
	LogitBias           map[string]int64 `json:"logit_bias"`
	LogProbs            *bool            `json:"log_probs"`
	TopLogProbs         *int64           `json:"top_log_probs"`
	ParallelToolCalls   *bool            `json:"parallel_tool_calls"`
	User                *string          `json:"user"`
	ReasoningEffort     *ReasoningEffort `json:"reasoning_effort"`
	MaxCompletionTokens *int64           `json:"max_completion_tokens"`
	TextVerbosity       *string          `json:"text_verbosity"`
	Prediction          map[string]any   `json:"prediction"`
	Store               *bool            `json:"store"`
	Metadata            map[string]any   `json:"metadata"`
	PromptCacheKey      *string          `json:"prompt_cache_key"`
	SafetyIdentifier    *string          `json:"safety_identifier"`
	ServiceTier         *string          `json:"service_tier"`
	StructuredOutputs   *bool            `json:"structured_outputs"`
}
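
// A sketch of how these options are supplied (assuming, as prepareParams
// does, that ai.Call.ProviderOptions carries a JSON-compatible map under the
// "openai" key; the values shown are illustrative, not defaults):
//
//	call.ProviderOptions = ai.ProviderOptions{
//		"openai": map[string]any{
//			"reasoning_effort": "low",
//			"service_tier":     "flex",
//			"user":             "user-1234",
//		},
//	}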

type openAiProvider struct {
	options openAiProviderOptions
}

type openAiProviderOptions struct {
	baseURL      string
	apiKey       string
	organization string
	project      string
	name         string
	headers      map[string]string
	client       option.HTTPClient
}

type OpenAiOption = func(*openAiProviderOptions)

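// NewOpenAiProvider constructs an ai.Provider backed by the OpenAI Chat
// Completions API. A minimal usage sketch (the env var and model ID are
// illustrative):
//
//	provider := NewOpenAiProvider(
//		WithOpenAiAPIKey(os.Getenv("OPENAI_API_KEY")),
//	)
//	model, err := provider.LanguageModel("gpt-4o")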
func NewOpenAiProvider(opts ...OpenAiOption) ai.Provider {
	options := openAiProviderOptions{
		headers: map[string]string{},
	}
	for _, o := range opts {
		o(&options)
	}

	if options.baseURL == "" {
		options.baseURL = "https://api.openai.com/v1"
	}

	if options.name == "" {
		options.name = "openai"
	}

	if options.organization != "" {
		options.headers["OpenAI-Organization"] = options.organization
	}

	if options.project != "" {
		options.headers["OpenAI-Project"] = options.project
	}

	return &openAiProvider{
		options: options,
	}
}

func WithOpenAiBaseURL(baseURL string) OpenAiOption {
	return func(o *openAiProviderOptions) {
		o.baseURL = baseURL
	}
}

func WithOpenAiAPIKey(apiKey string) OpenAiOption {
	return func(o *openAiProviderOptions) {
		o.apiKey = apiKey
	}
}

func WithOpenAiOrganization(organization string) OpenAiOption {
	return func(o *openAiProviderOptions) {
		o.organization = organization
	}
}

func WithOpenAiProject(project string) OpenAiOption {
	return func(o *openAiProviderOptions) {
		o.project = project
	}
}

func WithOpenAiName(name string) OpenAiOption {
	return func(o *openAiProviderOptions) {
		o.name = name
	}
}

func WithOpenAiHeaders(headers map[string]string) OpenAiOption {
	return func(o *openAiProviderOptions) {
		maps.Copy(o.headers, headers)
	}
}

func WithOpenAiHTTPClient(client option.HTTPClient) OpenAiOption {
	return func(o *openAiProviderOptions) {
		o.client = client
	}
}

// LanguageModel implements ai.Provider.
func (o *openAiProvider) LanguageModel(modelID string) (ai.LanguageModel, error) {
	openaiClientOptions := []option.RequestOption{}
	if o.options.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
	}
	if o.options.baseURL != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
	}

	for key, value := range o.options.headers {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	if o.options.client != nil {
		openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
	}

	return openAiLanguageModel{
		modelID:         modelID,
		provider:        fmt.Sprintf("%s.chat", o.options.name),
		providerOptions: o.options,
		client:          openai.NewClient(openaiClientOptions...),
	}, nil
}

type openAiLanguageModel struct {
	provider        string
	modelID         string
	client          openai.Client
	providerOptions openAiProviderOptions
}

// Model implements ai.LanguageModel.
func (o openAiLanguageModel) Model() string {
	return o.modelID
}

// Provider implements ai.LanguageModel.
func (o openAiLanguageModel) Provider() string {
	return o.provider
}

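// prepareParams translates a provider-agnostic ai.Call into
// openai.ChatCompletionNewParams: prompt messages, sampling settings,
// provider-specific options, and tool definitions. Settings the target model
// cannot honor are dropped and reported as ai.CallWarning values instead of
// failing the call.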
func (o openAiLanguageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
	params := &openai.ChatCompletionNewParams{}
	messages, warnings := toOpenAiPrompt(call.Prompt)
	providerOptions := &OpenAiProviderOptions{}
	if v, ok := call.ProviderOptions["openai"]; ok {
		err := ai.ParseOptions(v, providerOptions)
		if err != nil {
			return nil, nil, err
		}
	}
	if call.TopK != nil {
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedSetting,
			Setting: "top_k",
		})
	}
	params.Messages = messages
	params.Model = o.modelID
	if providerOptions.LogitBias != nil {
		params.LogitBias = providerOptions.LogitBias
	}
	// when both are set, prefer top_logprobs and drop the plain logprobs flag
	if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
		providerOptions.LogProbs = nil
	}
	if providerOptions.LogProbs != nil {
		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
	}
	if providerOptions.TopLogProbs != nil {
		params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
	}
	if providerOptions.User != nil {
		params.User = param.NewOpt(*providerOptions.User)
	}
	if providerOptions.ParallelToolCalls != nil {
		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
	}

	if call.MaxOutputTokens != nil {
		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
	}
	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}
	if call.FrequencyPenalty != nil {
		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
	}
	if call.PresencePenalty != nil {
		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
	}

	if providerOptions.MaxCompletionTokens != nil {
		params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
	}

	if providerOptions.TextVerbosity != nil {
		params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
	}
	if providerOptions.Prediction != nil {
		// Convert map[string]any to ChatCompletionPredictionContentParam
		if content, ok := providerOptions.Prediction["content"]; ok {
			if contentStr, ok := content.(string); ok {
				params.Prediction = openai.ChatCompletionPredictionContentParam{
					Content: openai.ChatCompletionPredictionContentContentUnionParam{
						OfString: param.NewOpt(contentStr),
					},
				}
			}
		}
	}
	if providerOptions.Store != nil {
		params.Store = param.NewOpt(*providerOptions.Store)
	}
	if providerOptions.Metadata != nil {
		// Convert map[string]any to map[string]string
		metadata := make(map[string]string)
		for k, v := range providerOptions.Metadata {
			if str, ok := v.(string); ok {
				metadata[k] = str
			}
		}
		params.Metadata = metadata
	}
	if providerOptions.PromptCacheKey != nil {
		params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
	}
	if providerOptions.SafetyIdentifier != nil {
		params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
	}
	if providerOptions.ServiceTier != nil {
		params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
	}

	if providerOptions.ReasoningEffort != nil {
		switch *providerOptions.ReasoningEffort {
		case ReasoningEffortMinimal:
			params.ReasoningEffort = shared.ReasoningEffortMinimal
		case ReasoningEffortLow:
			params.ReasoningEffort = shared.ReasoningEffortLow
		case ReasoningEffortMedium:
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case ReasoningEffortHigh:
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			return nil, nil, fmt.Errorf("reasoning effort `%s` not supported", *providerOptions.ReasoningEffort)
		}
	}

	if isReasoningModel(o.modelID) {
		// remove unsupported settings for reasoning models
		// see https://platform.openai.com/docs/guides/reasoning#limitations
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for reasoning models",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopP",
				Details: "TopP is not supported for reasoning models",
			})
		}
		if call.FrequencyPenalty != nil {
			params.FrequencyPenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "FrequencyPenalty",
				Details: "FrequencyPenalty is not supported for reasoning models",
			})
		}
		if call.PresencePenalty != nil {
			params.PresencePenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "PresencePenalty",
				Details: "PresencePenalty is not supported for reasoning models",
			})
		}
		if providerOptions.LogitBias != nil {
			params.LogitBias = nil
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "LogitBias",
				Message: "LogitBias is not supported for reasoning models",
			})
		}
		if providerOptions.LogProbs != nil {
			params.Logprobs = param.Opt[bool]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "Logprobs",
				Message: "Logprobs is not supported for reasoning models",
			})
		}
		if providerOptions.TopLogProbs != nil {
			params.TopLogprobs = param.Opt[int64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopLogprobs",
				Message: "TopLogprobs is not supported for reasoning models",
			})
		}

		// reasoning models use max_completion_tokens instead of max_tokens
		if call.MaxOutputTokens != nil {
			if providerOptions.MaxCompletionTokens == nil {
				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
			}
			params.MaxTokens = param.Opt[int64]{}
		}
	}

	// Handle search preview models
	if isSearchPreviewModel(o.modelID) {
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for the search preview models and has been removed.",
			})
		}
	}

	// Handle service tier validation
	if providerOptions.ServiceTier != nil {
		serviceTier := *providerOptions.ServiceTier
		if serviceTier == "flex" && !supportsFlexProcessing(o.modelID) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "ServiceTier",
				Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
			})
		} else if serviceTier == "priority" && !supportsPriorityProcessing(o.modelID) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "ServiceTier",
				Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported",
			})
		}
	}

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}
	return params, warnings, nil
}

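// handleError unwraps an *openai.Error into an ai.APICallError carrying the
// status code, lower-cased response headers, and request/response dumps; any
// other error is returned unchanged.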
func (o openAiLanguageModel) handleError(err error) error {
	var apiErr *openai.Error
	if errors.As(err, &apiErr) {
		requestDump := apiErr.DumpRequest(true)
		responseDump := apiErr.DumpResponse(true)
		headers := map[string]string{}
		for k, h := range apiErr.Response.Header {
			if len(h) == 0 {
				continue
			}
			// keep the last value for repeated headers
			v := h[len(h)-1]
			headers[strings.ToLower(k)] = v
		}
		return ai.NewAPICallError(
			apiErr.Message,
			apiErr.Request.URL.String(),
			string(requestDump),
			apiErr.StatusCode,
			headers,
			string(responseDump),
			apiErr,
			false,
		)
	}
	return err
}

// Generate implements ai.LanguageModel.
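// A usage sketch (field values are illustrative; Generate reads the prompt,
// sampling settings, and tools from the call):
//
//	temp := 0.2
//	resp, err := model.Generate(ctx, ai.Call{
//		Prompt:      prompt,
//		Temperature: &temp,
//	})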
func (o openAiLanguageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, o.handleError(err)
	}

	if len(response.Choices) == 0 {
		return nil, errors.New("no response generated")
	}
	choice := response.Choices[0]
	var content []ai.Content
	text := choice.Message.Content
	if text != "" {
		content = append(content, ai.TextContent{
			Text: text,
		})
	}

	for _, tc := range choice.Message.ToolCalls {
		toolCallID := tc.ID
		if toolCallID == "" {
			toolCallID = uuid.NewString()
		}
		content = append(content, ai.ToolCallContent{
			ProviderExecuted: false, // TODO: update when handling other tools
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Handle annotations/citations
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, ai.SourceContent{
				SourceType: ai.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	completionTokenDetails := response.Usage.CompletionTokensDetails
	promptTokenDetails := response.Usage.PromptTokensDetails

	// Build provider metadata
	providerMetadata := ai.ProviderMetadata{
		"openai": make(map[string]any),
	}

	// Add logprobs if available
	if len(choice.Logprobs.Content) > 0 {
		providerMetadata["openai"]["logprobs"] = choice.Logprobs.Content
	}

	// Add prediction tokens if available
	if completionTokenDetails.AcceptedPredictionTokens > 0 {
		providerMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
	}
	if completionTokenDetails.RejectedPredictionTokens > 0 {
		providerMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
	}

	return &ai.Response{
		Content: content,
		Usage: ai.Usage{
			InputTokens:     response.Usage.PromptTokens,
			OutputTokens:    response.Usage.CompletionTokens,
			TotalTokens:     response.Usage.TotalTokens,
			ReasoningTokens: completionTokenDetails.ReasoningTokens,
			CacheReadTokens: promptTokenDetails.CachedTokens,
		},
		FinishReason:     mapOpenAiFinishReason(choice.FinishReason),
		ProviderMetadata: providerMetadata,
		Warnings:         warnings,
	}, nil
}

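// toolCall accumulates one streamed tool call: argument deltas are appended
// until the buffered string parses as JSON, at which point the full call is
// emitted and hasFinished blocks further deltas for that index.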
type toolCall struct {
	id          string
	name        string
	arguments   string
	hasFinished bool
}

// Stream implements ai.LanguageModel.
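// The returned ai.StreamResponse is a yield-style iterator. A consumption
// sketch (assuming ai.StreamResponse is compatible with range-over-func):
//
//	stream, err := model.Stream(ctx, call)
//	if err != nil {
//		return err
//	}
//	for part := range stream {
//		switch part.Type {
//		case ai.StreamPartTypeTextDelta:
//			fmt.Print(part.Delta)
//		case ai.StreamPartTypeError:
//			return part.Error
//		}
//	}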
func (o openAiLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}

	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
	isActiveText := false
	toolCalls := make(map[int64]toolCall)

	// Build provider metadata for streaming
	streamProviderMetadata := ai.ProviderMetadata{
		"openai": make(map[string]any),
	}

	acc := openai.ChatCompletionAccumulator{}
	var usage ai.Usage
	return func(yield func(ai.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(ai.StreamPart{
				Type:     ai.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}
		for stream.Next() {
			chunk := stream.Current()
			acc.AddChunk(chunk)
			if chunk.Usage.TotalTokens > 0 {
				// capture usage from the chunk directly, because the
				// accumulator does not aggregate prompt token details
				completionTokenDetails := chunk.Usage.CompletionTokensDetails
				promptTokenDetails := chunk.Usage.PromptTokensDetails
				usage = ai.Usage{
					InputTokens:     chunk.Usage.PromptTokens,
					OutputTokens:    chunk.Usage.CompletionTokens,
					TotalTokens:     chunk.Usage.TotalTokens,
					ReasoningTokens: completionTokenDetails.ReasoningTokens,
					CacheReadTokens: promptTokenDetails.CachedTokens,
				}

				// Add prediction tokens if available
				if completionTokenDetails.AcceptedPredictionTokens > 0 {
					streamProviderMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
				}
				if completionTokenDetails.RejectedPredictionTokens > 0 {
					streamProviderMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
				}
			}
			if len(chunk.Choices) == 0 {
				continue
			}
			for _, choice := range chunk.Choices {
				switch {
				case choice.Delta.Content != "":
					if !isActiveText {
						isActiveText = true
						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeTextStart,
							ID:   "0",
						}) {
							return
						}
					}
					if !yield(ai.StreamPart{
						Type:  ai.StreamPartTypeTextDelta,
						ID:    "0",
						Delta: choice.Delta.Content,
					}) {
						return
					}
				case len(choice.Delta.ToolCalls) > 0:
					if isActiveText {
						isActiveText = false
						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeTextEnd,
							ID:   "0",
						}) {
							return
						}
					}

					for _, toolCallDelta := range choice.Delta.ToolCalls {
						if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
							if existingToolCall.hasFinished {
								continue
							}
							if toolCallDelta.Function.Arguments != "" {
								existingToolCall.arguments += toolCallDelta.Function.Arguments
							}
							if !yield(ai.StreamPart{
								Type:  ai.StreamPartTypeToolInputDelta,
								ID:    existingToolCall.id,
								Delta: toolCallDelta.Function.Arguments,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = existingToolCall
							if existingToolCall.arguments != "" && ai.IsParsableJSON(existingToolCall.arguments) {
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeToolInputEnd,
									ID:   existingToolCall.id,
								}) {
									return
								}

								if !yield(ai.StreamPart{
									Type:          ai.StreamPartTypeToolCall,
									ID:            existingToolCall.id,
									ToolCallName:  existingToolCall.name,
									ToolCallInput: existingToolCall.arguments,
								}) {
									return
								}
								existingToolCall.hasFinished = true
								toolCalls[toolCallDelta.Index] = existingToolCall
							}
						} else {
							// New tool call: validate the delta before tracking it.
							var err error
							if toolCallDelta.Type != "function" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
							}
							if toolCallDelta.ID == "" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
							}
							if toolCallDelta.Function.Name == "" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
							}
							if err != nil {
								yield(ai.StreamPart{
									Type:  ai.StreamPartTypeError,
									Error: err,
								})
								return
							}

							if !yield(ai.StreamPart{
								Type:         ai.StreamPartTypeToolInputStart,
								ID:           toolCallDelta.ID,
								ToolCallName: toolCallDelta.Function.Name,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = toolCall{
								id:        toolCallDelta.ID,
								name:      toolCallDelta.Function.Name,
								arguments: toolCallDelta.Function.Arguments,
							}

							exTc := toolCalls[toolCallDelta.Index]
							if exTc.arguments != "" {
								if !yield(ai.StreamPart{
									Type:  ai.StreamPartTypeToolInputDelta,
									ID:    exTc.id,
									Delta: exTc.arguments,
								}) {
									return
								}
								if ai.IsParsableJSON(toolCalls[toolCallDelta.Index].arguments) {
									if !yield(ai.StreamPart{
										Type: ai.StreamPartTypeToolInputEnd,
										ID:   toolCallDelta.ID,
									}) {
										return
									}

									if !yield(ai.StreamPart{
										Type:          ai.StreamPartTypeToolCall,
										ID:            exTc.id,
										ToolCallName:  exTc.name,
										ToolCallInput: exTc.arguments,
									}) {
										return
									}
									exTc.hasFinished = true
									toolCalls[toolCallDelta.Index] = exTc
								}
							}
							continue
						}
					}
				}
			}

			// Check for annotations in the delta's raw JSON
			for _, choice := range chunk.Choices {
				if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
					for _, annotation := range annotations {
						if annotation.Type == "url_citation" {
							if !yield(ai.StreamPart{
								Type:       ai.StreamPartTypeSource,
								ID:         uuid.NewString(),
								SourceType: ai.SourceTypeURL,
								URL:        annotation.URLCitation.URL,
								Title:      annotation.URLCitation.Title,
							}) {
								return
							}
						}
					}
				}
			}
		}
		err := stream.Err()
		if err == nil || errors.Is(err, io.EOF) {
			// finished
			if isActiveText {
				isActiveText = false
				if !yield(ai.StreamPart{
					Type: ai.StreamPartTypeTextEnd,
					ID:   "0",
				}) {
					return
				}
			}

			// Add logprobs if available
			if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
				streamProviderMetadata["openai"]["logprobs"] = acc.Choices[0].Logprobs.Content
			}

			// Handle annotations/citations from accumulated response
			if len(acc.Choices) > 0 {
				for _, annotation := range acc.Choices[0].Message.Annotations {
					if annotation.Type == "url_citation" {
						if !yield(ai.StreamPart{
							Type:       ai.StreamPartTypeSource,
							ID:         acc.ID,
							SourceType: ai.SourceTypeURL,
							URL:        annotation.URLCitation.URL,
							Title:      annotation.URLCitation.Title,
						}) {
							return
						}
					}
				}
			}

			finishReason := ai.FinishReasonUnknown
			if len(acc.Choices) > 0 {
				finishReason = mapOpenAiFinishReason(acc.Choices[0].FinishReason)
			}
			yield(ai.StreamPart{
				Type:             ai.StreamPartTypeFinish,
				Usage:            usage,
				FinishReason:     finishReason,
				ProviderMetadata: streamProviderMetadata,
			})
			return
		} else {
			yield(ai.StreamPart{
				Type:  ai.StreamPartTypeError,
				Error: o.handleError(err),
			})
			return
		}
	}, nil
}

func mapOpenAiFinishReason(finishReason string) ai.FinishReason {
	switch finishReason {
	case "stop":
		return ai.FinishReasonStop
	case "length":
		return ai.FinishReasonLength
	case "content_filter":
		return ai.FinishReasonContentFilter
	case "function_call", "tool_calls":
		return ai.FinishReasonToolCalls
	default:
		return ai.FinishReasonUnknown
	}
}

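// The helpers below classify models by ID prefix. For example,
// isReasoningModel("o3-mini") and isReasoningModel("gpt-5") are true, while
// isReasoningModel("gpt-4o") and isReasoningModel("gpt-5-chat") are false.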
func isReasoningModel(modelID string) bool {
	// gpt-5-chat models are conversational, not reasoning, so they are
	// excluded from the gpt-5 family check.
	return strings.HasPrefix(modelID, "o") ||
		(strings.HasPrefix(modelID, "gpt-5") && !strings.HasPrefix(modelID, "gpt-5-chat"))
}

func isSearchPreviewModel(modelID string) bool {
	return strings.Contains(modelID, "search-preview")
}

func supportsFlexProcessing(modelID string) bool {
	return strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini") || strings.HasPrefix(modelID, "gpt-5")
}

func supportsPriorityProcessing(modelID string) bool {
	// gpt-5-nano does not support priority processing (see the ServiceTier
	// warning in prepareParams), so it is excluded from the gpt-5 family
	// check; gpt-5-mini is matched explicitly.
	return strings.HasPrefix(modelID, "gpt-4") || strings.HasPrefix(modelID, "gpt-5-mini") ||
		(strings.HasPrefix(modelID, "gpt-5") && !strings.HasPrefix(modelID, "gpt-5-nano")) ||
		strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini")
}

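// toOpenAiTools converts ai function tools into Chat Completions tool
// definitions. A sketch of the input shape it expects (the name and schema
// are illustrative; InputSchema is assumed to be a JSON-schema-style map):
//
//	tool := ai.FunctionTool{
//		Name:        "get_weather",
//		Description: "Look up the current weather for a city",
//		InputSchema: map[string]any{
//			"type": "object",
//			"properties": map[string]any{
//				"city": map[string]any{"type": "string"},
//			},
//		},
//	}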
func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				continue
			}
			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
				OfFunction: &openai.ChatCompletionFunctionToolParam{
					Function: shared.FunctionDefinitionParam{
						Name:        ft.Name,
						Description: param.NewOpt(ft.Description),
						Parameters:  openai.FunctionParameters(ft.InputSchema),
						Strict:      param.NewOpt(false),
					},
					Type: "function",
				},
			})
			continue
		}

		// TODO: handle provider tool calls
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return
	}

	switch *toolChoice {
	case ai.ToolChoiceAuto:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("auto"),
		}
	case ai.ToolChoiceNone:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("none"),
		}
	default:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
				Type: "function",
				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
					Name: string(*toolChoice),
				},
			},
		}
	}
	return
}

func toOpenAiPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
	var messages []openai.ChatCompletionMessageParamUnion
	var warnings []ai.CallWarning
	for _, msg := range prompt {
		switch msg.Role {
		case ai.MessageRoleSystem:
			var systemPromptParts []string
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeText {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt can only have text content",
					})
					continue
				}
				textPart, ok := ai.AsContentType[ai.TextPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt text part does not have the right type",
					})
					continue
				}
				text := textPart.Text
				if strings.TrimSpace(text) != "" {
					systemPromptParts = append(systemPromptParts, textPart.Text)
				}
			}
			if len(systemPromptParts) == 0 {
				warnings = append(warnings, ai.CallWarning{
					Type:    ai.CallWarningTypeOther,
					Message: "system prompt has no text parts",
				})
				continue
			}
			messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
		case ai.MessageRoleUser:
			// A simple user message: just text content.
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "user message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.UserMessage(textPart.Text))
				continue
			}
			// Text content plus attachments. For now only image, audio, and
			// PDF parts are supported.
			// TODO: add the supported media types to the language model so we
			// can use that to validate the data here.
			var content []openai.ChatCompletionContentPartUnionParam
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message text part does not have the right type",
						})
						continue
					}
					content = append(content, openai.ChatCompletionContentPartUnionParam{
						OfText: &openai.ChatCompletionContentPartTextParam{
							Text: textPart.Text,
						},
					})
				case ai.ContentTypeFile:
					filePart, ok := ai.AsContentType[ai.FilePart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message file part does not have the right type",
						})
						continue
					}

					switch {
					case strings.HasPrefix(filePart.MediaType, "image/"):
						// Handle image files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
						imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}

						// Check for provider-specific options like image detail
						if providerOptions, ok := filePart.ProviderOptions["openai"]; ok {
							if detail, ok := providerOptions["imageDetail"].(string); ok {
								imageURL.Detail = detail
							}
						}

						imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})

					case filePart.MediaType == "audio/wav":
						// Handle WAV audio files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "wav",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
						// Handle MP3 audio files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "mp3",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "application/pdf":
						// Handle PDF files
						dataStr := string(filePart.Data)

						// Check if data looks like a file ID (starts with "file-")
						if strings.HasPrefix(dataStr, "file-") {
							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									FileID: param.NewOpt(dataStr),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						} else {
							// Handle as base64 data
							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
							data := "data:application/pdf;base64," + base64Encoded

							filename := filePart.Filename
							if filename == "" {
								// Generate default filename based on content index
								filename = fmt.Sprintf("part-%d.pdf", len(content))
							}

							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									Filename: param.NewOpt(filename),
									FileData: param.NewOpt(data),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						}

					default:
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
						})
					}
				}
			}
			messages = append(messages, openai.UserMessage(content))
		case ai.MessageRoleAssistant:
			// A simple assistant message: just text content.
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "assistant message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.AssistantMessage(textPart.Text))
				continue
			}
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message text part does not have the right type",
						})
						continue
					}
					assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
						OfString: param.NewOpt(textPart.Text),
					}
				case ai.ContentTypeToolCall:
					toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message tool part does not have the right type",
						})
						continue
					}
					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
						openai.ChatCompletionMessageToolCallUnionParam{
							OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
								ID:   toolCallPart.ToolCallID,
								Type: "function",
								Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
									Name:      toolCallPart.ToolName,
									Arguments: toolCallPart.Input,
								},
							},
						})
				}
			}
			messages = append(messages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})
		case ai.MessageRoleTool:
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeToolResult {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message can only have tool result content",
					})
					continue
				}

				toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message result part does not have the right type",
					})
					continue
				}

				switch toolResultPart.Output.GetType() {
				case ai.ToolResultContentTypeText:
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
				case ai.ToolResultContentTypeError:
					// TODO: check if better handling is needed
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
				}
			}
		}
	}
	return messages, warnings
}

// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
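// The delta shape it looks for is (other fields are ignored; the URL and
// title values here are illustrative):
//
//	{"annotations": [{"type": "url_citation",
//		"url_citation": {"url": "https://example.com", "title": "Example"}}]}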
func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
	var annotations []openai.ChatCompletionMessageAnnotation

	// Parse the raw JSON to extract annotations
	var deltaData map[string]any
	if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
		return annotations
	}

	// Check if annotations exist in the delta
	annotationsData, ok := deltaData["annotations"].([]any)
	if !ok {
		return annotations
	}
	for _, annotationData := range annotationsData {
		annotationMap, ok := annotationData.(map[string]any)
		if !ok {
			continue
		}
		if annotationType, ok := annotationMap["type"].(string); !ok || annotationType != "url_citation" {
			continue
		}
		urlCitationData, ok := annotationMap["url_citation"].(map[string]any)
		if !ok {
			continue
		}
		// Use comma-ok assertions so a malformed annotation cannot panic.
		url, _ := urlCitationData["url"].(string)
		title, _ := urlCitationData["title"].(string)
		annotations = append(annotations, openai.ChatCompletionMessageAnnotation{
			Type: "url_citation",
			URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
				URL:   url,
				Title: title,
			},
		})
	}

	return annotations
}