openai.go

package providers

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"maps"
	"strings"

	"github.com/charmbracelet/crush/internal/ai"
	"github.com/google/uuid"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/option"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/openai/openai-go/v2/shared"
)

type ReasoningEffort string

const (
	ReasoningEffortMinimal ReasoningEffort = "minimal"
	ReasoningEffortLow     ReasoningEffort = "low"
	ReasoningEffortMedium  ReasoningEffort = "medium"
	ReasoningEffortHigh    ReasoningEffort = "high"
)

type OpenAIProviderOptions struct {
	LogitBias           map[string]int64 `json:"logit_bias"`
	LogProbs            *bool            `json:"log_probs"`
	TopLogProbs         *int64           `json:"top_log_probs"`
	ParallelToolCalls   *bool            `json:"parallel_tool_calls"`
	User                *string          `json:"user"`
	ReasoningEffort     *ReasoningEffort `json:"reasoning_effort"`
	MaxCompletionTokens *int64           `json:"max_completion_tokens"`
	TextVerbosity       *string          `json:"text_verbosity"`
	Prediction          map[string]any   `json:"prediction"`
	Store               *bool            `json:"store"`
	Metadata            map[string]any   `json:"metadata"`
	PromptCacheKey      *string          `json:"prompt_cache_key"`
	SafetyIdentifier    *string          `json:"safety_identifier"`
	ServiceTier         *string          `json:"service_tier"`
	StructuredOutputs   *bool            `json:"structured_outputs"`
}

type openAIProvider struct {
	options openAIProviderOptions
}

type openAIProviderOptions struct {
	baseURL      string
	apiKey       string
	organization string
	project      string
	name         string
	headers      map[string]string
	client       option.HTTPClient
}

type OpenAIOption = func(*openAIProviderOptions)

func NewOpenAIProvider(opts ...OpenAIOption) ai.Provider {
	options := openAIProviderOptions{
		headers: map[string]string{},
	}
	for _, o := range opts {
		o(&options)
	}

	if options.baseURL == "" {
		options.baseURL = "https://api.openai.com/v1"
	}

	if options.name == "" {
		options.name = "openai"
	}

	if options.organization != "" {
		options.headers["OpenAI-Organization"] = options.organization
	}

	if options.project != "" {
		options.headers["OpenAI-Project"] = options.project
	}

	return &openAIProvider{
		options: options,
	}
}

func WithOpenAIBaseURL(baseURL string) OpenAIOption {
	return func(o *openAIProviderOptions) {
		o.baseURL = baseURL
	}
}

func WithOpenAIApiKey(apiKey string) OpenAIOption {
	return func(o *openAIProviderOptions) {
		o.apiKey = apiKey
	}
}

func WithOpenAIOrganization(organization string) OpenAIOption {
	return func(o *openAIProviderOptions) {
		o.organization = organization
	}
}

func WithOpenAIProject(project string) OpenAIOption {
	return func(o *openAIProviderOptions) {
		o.project = project
	}
}

func WithOpenAIName(name string) OpenAIOption {
	return func(o *openAIProviderOptions) {
		o.name = name
	}
}

func WithOpenAIHeaders(headers map[string]string) OpenAIOption {
	return func(o *openAIProviderOptions) {
		maps.Copy(o.headers, headers)
	}
}

func WithOpenAIHttpClient(client option.HTTPClient) OpenAIOption {
	return func(o *openAIProviderOptions) {
		o.client = client
	}
}
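
// Example (illustrative sketch, not part of the package API): wiring the
// options above together. The API key lookup, header, and model ID are
// assumptions for demonstration purposes only.
//
//	provider := NewOpenAIProvider(
//		WithOpenAIApiKey(os.Getenv("OPENAI_API_KEY")),
//		WithOpenAIOrganization("org-example"),
//		WithOpenAIHeaders(map[string]string{"X-Request-Source": "example"}),
//	)
//	model := provider.LanguageModel("gpt-4o")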

// LanguageModel implements ai.Provider.
func (o *openAIProvider) LanguageModel(modelID string) ai.LanguageModel {
	openaiClientOptions := []option.RequestOption{}
	if o.options.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
	}
	if o.options.baseURL != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
	}

	for key, value := range o.options.headers {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	if o.options.client != nil {
		openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
	}

	return openAILanguageModel{
		modelID:         modelID,
		provider:        fmt.Sprintf("%s.chat", o.options.name),
		providerOptions: o.options,
		client:          openai.NewClient(openaiClientOptions...),
	}
}

type openAILanguageModel struct {
	provider        string
	modelID         string
	client          openai.Client
	providerOptions openAIProviderOptions
}

// Model implements ai.LanguageModel.
func (o openAILanguageModel) Model() string {
	return o.modelID
}

// Provider implements ai.LanguageModel.
func (o openAILanguageModel) Provider() string {
	return o.provider
}

func (o openAILanguageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
	params := &openai.ChatCompletionNewParams{}
	messages, warnings := toOpenAIPrompt(call.Prompt)
	providerOptions := &OpenAIProviderOptions{}
	if v, ok := call.ProviderOptions["openai"]; ok {
		err := ai.ParseOptions(v, providerOptions)
		if err != nil {
			return nil, nil, err
		}
	}
	if call.TopK != nil {
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedSetting,
			Setting: "top_k",
		})
	}
	params.Messages = messages
	params.Model = o.modelID
	if providerOptions.LogitBias != nil {
		params.LogitBias = providerOptions.LogitBias
	}
	if providerOptions.LogProbs != nil {
		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
	}
	if providerOptions.TopLogProbs != nil {
		params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
		// the API only accepts top_logprobs when logprobs is enabled
		params.Logprobs = param.NewOpt(true)
	}
	if providerOptions.User != nil {
		params.User = param.NewOpt(*providerOptions.User)
	}
	if providerOptions.ParallelToolCalls != nil {
		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
	}

	if call.MaxOutputTokens != nil {
		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
	}
	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}
	if call.FrequencyPenalty != nil {
		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
	}
	if call.PresencePenalty != nil {
		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
	}

	if providerOptions.MaxCompletionTokens != nil {
		params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
	}

	if providerOptions.TextVerbosity != nil {
		params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
	}
	if providerOptions.Prediction != nil {
		// Convert map[string]any to ChatCompletionPredictionContentParam
		if content, ok := providerOptions.Prediction["content"]; ok {
			if contentStr, ok := content.(string); ok {
				params.Prediction = openai.ChatCompletionPredictionContentParam{
					Content: openai.ChatCompletionPredictionContentContentUnionParam{
						OfString: param.NewOpt(contentStr),
					},
				}
			}
		}
	}
	if providerOptions.Store != nil {
		params.Store = param.NewOpt(*providerOptions.Store)
	}
	if providerOptions.Metadata != nil {
		// Convert map[string]any to map[string]string
		metadata := make(map[string]string)
		for k, v := range providerOptions.Metadata {
			if str, ok := v.(string); ok {
				metadata[k] = str
			}
		}
		params.Metadata = metadata
	}
	if providerOptions.PromptCacheKey != nil {
		params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
	}
	if providerOptions.SafetyIdentifier != nil {
		params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
	}
	if providerOptions.ServiceTier != nil {
		params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
	}

	if providerOptions.ReasoningEffort != nil {
		switch *providerOptions.ReasoningEffort {
		case ReasoningEffortMinimal:
			params.ReasoningEffort = shared.ReasoningEffortMinimal
		case ReasoningEffortLow:
			params.ReasoningEffort = shared.ReasoningEffortLow
		case ReasoningEffortMedium:
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case ReasoningEffortHigh:
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			return nil, nil, fmt.Errorf("reasoning effort `%s` not supported", *providerOptions.ReasoningEffort)
		}
	}

	if isReasoningModel(o.modelID) {
		// remove unsupported settings for reasoning models
		// see https://platform.openai.com/docs/guides/reasoning#limitations
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for reasoning models",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "top_p",
				Details: "topP is not supported for reasoning models",
			})
		}
		if call.FrequencyPenalty != nil {
			params.FrequencyPenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "frequency_penalty",
				Details: "frequencyPenalty is not supported for reasoning models",
			})
		}
		if call.PresencePenalty != nil {
			params.PresencePenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "presence_penalty",
				Details: "presencePenalty is not supported for reasoning models",
			})
		}
		if providerOptions.LogitBias != nil {
			params.LogitBias = nil
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeOther,
				Message: "logitBias is not supported for reasoning models",
			})
		}
		if providerOptions.LogProbs != nil {
			params.Logprobs = param.Opt[bool]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeOther,
				Message: "logprobs is not supported for reasoning models",
			})
		}
		if providerOptions.TopLogProbs != nil {
			params.TopLogprobs = param.Opt[int64]{}
			// also clear the logprobs flag implied by top_logprobs
			params.Logprobs = param.Opt[bool]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeOther,
				Message: "topLogprobs is not supported for reasoning models",
			})
		}

		// reasoning models use max_completion_tokens instead of max_tokens
		if call.MaxOutputTokens != nil {
			if providerOptions.MaxCompletionTokens == nil {
				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
			}
			params.MaxTokens = param.Opt[int64]{}
		}
	}

	// Handle search preview models
	if isSearchPreviewModel(o.modelID) {
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for the search preview models and has been removed.",
			})
		}
	}

	// Handle service tier validation
	if providerOptions.ServiceTier != nil {
		serviceTier := *providerOptions.ServiceTier
		if serviceTier == "flex" && !supportsFlexProcessing(o.modelID) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "serviceTier",
				Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
			})
		} else if serviceTier == "priority" && !supportsPriorityProcessing(o.modelID) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "serviceTier",
				Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported",
			})
		}
	}

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toOpenAITools(call.Tools, call.ToolChoice)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}
	return params, warnings, nil
}
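
// Example (sketch): provider-specific options travel on the call under the
// "openai" key and are decoded by ai.ParseOptions into OpenAIProviderOptions
// above. The keys mirror that struct's json tags; the values shown here are
// illustrative.
//
//	call := ai.Call{
//		Prompt: prompt,
//		ProviderOptions: ai.ProviderOptions{
//			"openai": map[string]any{
//				"reasoning_effort":      "medium",
//				"max_completion_tokens": 2048,
//				"service_tier":          "flex",
//			},
//		},
//	}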

// Generate implements ai.LanguageModel.
func (o openAILanguageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, err
	}

	if len(response.Choices) == 0 {
		return nil, errors.New("no response generated")
	}
	choice := response.Choices[0]
	var content []ai.Content
	text := choice.Message.Content
	if text != "" {
		content = append(content, ai.TextContent{
			Text: text,
		})
	}

	for _, tc := range choice.Message.ToolCalls {
		toolCallID := tc.ID
		if toolCallID == "" {
			toolCallID = uuid.NewString()
		}
		content = append(content, ai.ToolCallContent{
			ProviderExecuted: false, // TODO: update when handling other tools
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Handle annotations/citations
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, ai.SourceContent{
				SourceType: ai.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	completionTokenDetails := response.Usage.CompletionTokensDetails
	promptTokenDetails := response.Usage.PromptTokensDetails

	// Build provider metadata
	providerMetadata := ai.ProviderMetadata{
		"openai": make(map[string]any),
	}

	// Add logprobs if available
	if len(choice.Logprobs.Content) > 0 {
		providerMetadata["openai"]["logprobs"] = choice.Logprobs.Content
	}

	// Add prediction tokens if available
	if completionTokenDetails.AcceptedPredictionTokens > 0 {
		providerMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
	}
	if completionTokenDetails.RejectedPredictionTokens > 0 {
		providerMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
	}

	return &ai.Response{
		Content: content,
		Usage: ai.Usage{
			InputTokens:     response.Usage.PromptTokens,
			OutputTokens:    response.Usage.CompletionTokens,
			TotalTokens:     response.Usage.TotalTokens,
			ReasoningTokens: completionTokenDetails.ReasoningTokens,
			CacheReadTokens: promptTokenDetails.CachedTokens,
		},
		FinishReason:     mapOpenAIFinishReason(choice.FinishReason),
		ProviderMetadata: providerMetadata,
		Warnings:         warnings,
	}, nil
}
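
// Example (sketch): a minimal blocking call through Generate. Prompt
// construction is elided; the model ID is illustrative, and the type switch
// assumes the ai.TextContent values appended in Generate above.
//
//	model := NewOpenAIProvider(WithOpenAIApiKey(key)).LanguageModel("gpt-4o")
//	resp, err := model.Generate(ctx, ai.Call{Prompt: prompt})
//	if err != nil {
//		// handle error
//	}
//	for _, c := range resp.Content {
//		if tc, ok := c.(ai.TextContent); ok {
//			fmt.Print(tc.Text)
//		}
//	}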

type toolCall struct {
	id          string
	name        string
	arguments   string
	hasFinished bool
}

// Stream implements ai.LanguageModel.
func (o openAILanguageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}

	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
	isActiveText := false
	toolCalls := make(map[int64]toolCall)

	// Build provider metadata for streaming
	streamProviderMetadata := ai.ProviderMetadata{
		"openai": make(map[string]any),
	}

	acc := openai.ChatCompletionAccumulator{}
	var usage ai.Usage
	return func(yield func(ai.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(ai.StreamPart{
				Type:     ai.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}
		for stream.Next() {
			chunk := stream.Current()
			acc.AddChunk(chunk)
			if chunk.Usage.TotalTokens > 0 {
				// we do this here because the acc does not add prompt details
				completionTokenDetails := chunk.Usage.CompletionTokensDetails
				promptTokenDetails := chunk.Usage.PromptTokensDetails
				usage = ai.Usage{
					InputTokens:     chunk.Usage.PromptTokens,
					OutputTokens:    chunk.Usage.CompletionTokens,
					TotalTokens:     chunk.Usage.TotalTokens,
					ReasoningTokens: completionTokenDetails.ReasoningTokens,
					CacheReadTokens: promptTokenDetails.CachedTokens,
				}

				// Add prediction tokens if available
				if completionTokenDetails.AcceptedPredictionTokens > 0 {
					streamProviderMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
				}
				if completionTokenDetails.RejectedPredictionTokens > 0 {
					streamProviderMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
				}
			}
			if len(chunk.Choices) == 0 {
				continue
			}
			for _, choice := range chunk.Choices {
				switch {
				case choice.Delta.Content != "":
					if !isActiveText {
						isActiveText = true
						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeTextStart,
							ID:   "0",
						}) {
							return
						}
					}
					if !yield(ai.StreamPart{
						Type:  ai.StreamPartTypeTextDelta,
						ID:    "0",
						Delta: choice.Delta.Content,
					}) {
						return
					}
				case len(choice.Delta.ToolCalls) > 0:
					if isActiveText {
						isActiveText = false
						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeTextEnd,
							ID:   "0",
						}) {
							return
						}
					}

					for _, toolCallDelta := range choice.Delta.ToolCalls {
						if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
							if existingToolCall.hasFinished {
								continue
							}
							if toolCallDelta.Function.Arguments != "" {
								existingToolCall.arguments += toolCallDelta.Function.Arguments
							}
							if !yield(ai.StreamPart{
								Type:  ai.StreamPartTypeToolInputDelta,
								ID:    existingToolCall.id,
								Delta: toolCallDelta.Function.Arguments,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = existingToolCall
							if existingToolCall.arguments != "" && ai.IsParsableJSON(existingToolCall.arguments) {
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeToolInputEnd,
									ID:   existingToolCall.id,
								}) {
									return
								}

								if !yield(ai.StreamPart{
									Type:          ai.StreamPartTypeToolCall,
									ID:            existingToolCall.id,
									ToolCallName:  existingToolCall.name,
									ToolCallInput: existingToolCall.arguments,
								}) {
									return
								}
								existingToolCall.hasFinished = true
								toolCalls[toolCallDelta.Index] = existingToolCall
							}

						} else {
							// Does not exist
							var err error
							if toolCallDelta.Type != "function" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
							}
							if toolCallDelta.ID == "" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
							}
							if toolCallDelta.Function.Name == "" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
							}
							if err != nil {
								yield(ai.StreamPart{
									Type:  ai.StreamPartTypeError,
									Error: err,
								})
								return
							}

							if !yield(ai.StreamPart{
								Type:         ai.StreamPartTypeToolInputStart,
								ID:           toolCallDelta.ID,
								ToolCallName: toolCallDelta.Function.Name,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = toolCall{
								id:        toolCallDelta.ID,
								name:      toolCallDelta.Function.Name,
								arguments: toolCallDelta.Function.Arguments,
							}

							exTc := toolCalls[toolCallDelta.Index]
							if exTc.arguments != "" {
								if !yield(ai.StreamPart{
									Type:  ai.StreamPartTypeToolInputDelta,
									ID:    exTc.id,
									Delta: exTc.arguments,
								}) {
									return
								}
								if ai.IsParsableJSON(exTc.arguments) {
									if !yield(ai.StreamPart{
										Type: ai.StreamPartTypeToolInputEnd,
										ID:   exTc.id,
									}) {
										return
									}

									if !yield(ai.StreamPart{
										Type:          ai.StreamPartTypeToolCall,
										ID:            exTc.id,
										ToolCallName:  exTc.name,
										ToolCallInput: exTc.arguments,
									}) {
										return
									}
									exTc.hasFinished = true
									toolCalls[toolCallDelta.Index] = exTc
								}
							}
							continue
						}
					}
				}
			}

			// Check for annotations in the delta's raw JSON
			for _, choice := range chunk.Choices {
				if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
					for _, annotation := range annotations {
						if annotation.Type == "url_citation" {
							if !yield(ai.StreamPart{
								Type:       ai.StreamPartTypeSource,
								ID:         uuid.NewString(),
								SourceType: ai.SourceTypeURL,
								URL:        annotation.URLCitation.URL,
								Title:      annotation.URLCitation.Title,
							}) {
								return
							}
						}
					}
				}
			}
		}
		err := stream.Err()
		if err == nil || errors.Is(err, io.EOF) {
			// finished
			if isActiveText {
				isActiveText = false
				if !yield(ai.StreamPart{
					Type: ai.StreamPartTypeTextEnd,
					ID:   "0",
				}) {
					return
				}
			}

			// Add logprobs if available
			if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
				streamProviderMetadata["openai"]["logprobs"] = acc.Choices[0].Logprobs.Content
			}

			// Handle annotations/citations from accumulated response
			if len(acc.Choices) > 0 {
				for _, annotation := range acc.Choices[0].Message.Annotations {
					if annotation.Type == "url_citation" {
						if !yield(ai.StreamPart{
							Type:       ai.StreamPartTypeSource,
							ID:         uuid.NewString(),
							SourceType: ai.SourceTypeURL,
							URL:        annotation.URLCitation.URL,
							Title:      annotation.URLCitation.Title,
						}) {
							return
						}
					}
				}
			}

			finishReason := ai.FinishReasonUnknown
			if len(acc.Choices) > 0 {
				finishReason = mapOpenAIFinishReason(acc.Choices[0].FinishReason)
			}
			yield(ai.StreamPart{
				Type:             ai.StreamPartTypeFinish,
				Usage:            usage,
				FinishReason:     finishReason,
				ProviderMetadata: streamProviderMetadata,
			})
			return
		} else {
			yield(ai.StreamPart{
				Type:  ai.StreamPartTypeError,
				Error: err,
			})
			return
		}
	}, nil
}
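
// Example (sketch): consuming the stream. The returned ai.StreamResponse is
// the yield-based sequence built above, so with Go 1.23+ it can be ranged
// over directly; the part types handled here are a subset of those emitted.
//
//	parts, err := model.Stream(ctx, ai.Call{Prompt: prompt})
//	if err != nil {
//		// handle error
//	}
//	for part := range parts {
//		switch part.Type {
//		case ai.StreamPartTypeTextDelta:
//			fmt.Print(part.Delta)
//		case ai.StreamPartTypeFinish:
//			fmt.Println("\ntokens used:", part.Usage.TotalTokens)
//		}
//	}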

func mapOpenAIFinishReason(finishReason string) ai.FinishReason {
	switch finishReason {
	case "stop":
		return ai.FinishReasonStop
	case "length":
		return ai.FinishReasonLength
	case "content_filter":
		return ai.FinishReasonContentFilter
	case "function_call", "tool_calls":
		return ai.FinishReasonToolCalls
	default:
		return ai.FinishReasonUnknown
	}
}

func isReasoningModel(modelID string) bool {
	// the gpt-5 prefix also covers gpt-5-chat
	return strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5")
}

func isSearchPreviewModel(modelID string) bool {
	return strings.Contains(modelID, "search-preview")
}

func supportsFlexProcessing(modelID string) bool {
	return strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini") || strings.HasPrefix(modelID, "gpt-5")
}

func supportsPriorityProcessing(modelID string) bool {
	// the gpt-5 prefix covers gpt-5-mini but must not match gpt-5-nano,
	// which the warning in prepareParams calls out as unsupported
	return strings.HasPrefix(modelID, "gpt-4") ||
		(strings.HasPrefix(modelID, "gpt-5") && !strings.HasPrefix(modelID, "gpt-5-nano")) ||
		strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini")
}

func toOpenAITools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAITools []openai.ChatCompletionToolUnionParam, openAIToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				continue
			}
			openAITools = append(openAITools, openai.ChatCompletionToolUnionParam{
				OfFunction: &openai.ChatCompletionFunctionToolParam{
					Function: shared.FunctionDefinitionParam{
						Name:        ft.Name,
						Description: param.NewOpt(ft.Description),
						Parameters:  openai.FunctionParameters(ft.InputSchema),
						Strict:      param.NewOpt(false),
					},
					Type: "function",
				},
			})
			continue
		}

		// TODO: handle provider tool calls
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return
	}

	switch *toolChoice {
	case ai.ToolChoiceAuto:
		openAIToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("auto"),
		}
	case ai.ToolChoiceNone:
		openAIToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("none"),
		}
	default:
		openAIToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
				Type: "function",
				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
					Name: string(*toolChoice),
				},
			},
		}
	}
	return
}
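
// Example (sketch): a function tool in the shape toOpenAITools expects. The
// field values are illustrative, and InputSchema is assumed to be the
// JSON-schema map the ai package carries for a tool; the string conversion
// of ai.ToolChoice mirrors the default case above.
//
//	weatherTool := ai.FunctionTool{
//		Name:        "get_weather",
//		Description: "Look up the current weather for a city.",
//		InputSchema: map[string]any{
//			"type": "object",
//			"properties": map[string]any{
//				"city": map[string]any{"type": "string"},
//			},
//			"required": []string{"city"},
//		},
//	}
//	choice := ai.ToolChoice("get_weather") // or ai.ToolChoiceAuto / ai.ToolChoiceNone
//	tools, toolChoice, warnings := toOpenAITools([]ai.Tool{weatherTool}, &choice)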

func toOpenAIPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
	var messages []openai.ChatCompletionMessageParamUnion
	var warnings []ai.CallWarning
	for _, msg := range prompt {
		switch msg.Role {
		case ai.MessageRoleSystem:
			var systemPromptParts []string
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeText {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt can only have text content",
					})
					continue
				}
				textPart, ok := ai.AsContentType[ai.TextPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt text part does not have the right type",
					})
					continue
				}
				text := textPart.Text
				if strings.TrimSpace(text) != "" {
					systemPromptParts = append(systemPromptParts, text)
				}
			}
			if len(systemPromptParts) == 0 {
				warnings = append(warnings, ai.CallWarning{
					Type:    ai.CallWarningTypeOther,
					Message: "system prompt has no text parts",
				})
				continue
			}
			messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
		case ai.MessageRoleUser:
			// simple user message with just text content
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "user message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.UserMessage(textPart.Text))
				continue
			}
			// text content with attachments; only a limited set of media
			// types (images, wav/mp3 audio, and PDFs) is handled below
			// TODO: add the supported media types to the language model so we
			// can use that to validate the data here.
			var content []openai.ChatCompletionContentPartUnionParam
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message text part does not have the right type",
						})
						continue
					}
					content = append(content, openai.ChatCompletionContentPartUnionParam{
						OfText: &openai.ChatCompletionContentPartTextParam{
							Text: textPart.Text,
						},
					})
				case ai.ContentTypeFile:
					filePart, ok := ai.AsContentType[ai.FilePart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message file part does not have the right type",
						})
						continue
					}

					switch {
					case strings.HasPrefix(filePart.MediaType, "image/"):
						// Handle image files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
						imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}

						// Check for provider-specific options like image detail
						if providerOptions, ok := filePart.ProviderOptions["openai"]; ok {
							if detail, ok := providerOptions["imageDetail"].(string); ok {
								imageURL.Detail = detail
							}
						}

						imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})

					case filePart.MediaType == "audio/wav":
						// Handle WAV audio files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "wav",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
						// Handle MP3 audio files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "mp3",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "application/pdf":
						// Handle PDF files
						dataStr := string(filePart.Data)

						// Check if data looks like a file ID (starts with "file-")
						if strings.HasPrefix(dataStr, "file-") {
							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									FileID: param.NewOpt(dataStr),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						} else {
							// Handle as base64 data
							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
							data := "data:application/pdf;base64," + base64Encoded

							filename := filePart.Filename
							if filename == "" {
								// Generate default filename based on content index
								filename = fmt.Sprintf("part-%d.pdf", len(content))
							}

							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									Filename: param.NewOpt(filename),
									FileData: param.NewOpt(data),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						}

					default:
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
						})
					}
				}
			}
			messages = append(messages, openai.UserMessage(content))
		case ai.MessageRoleAssistant:
			// simple assistant message with just text content
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "assistant message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.AssistantMessage(textPart.Text))
				continue
			}
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message text part does not have the right type",
						})
						continue
					}
					assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
						OfString: param.NewOpt(textPart.Text),
					}
				case ai.ContentTypeToolCall:
					toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message tool part does not have the right type",
						})
						continue
					}
					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
						openai.ChatCompletionMessageToolCallUnionParam{
							OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
								ID:   toolCallPart.ToolCallID,
								Type: "function",
								Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
									Name:      toolCallPart.ToolName,
									Arguments: toolCallPart.Input,
								},
							},
						})
				}
			}
			messages = append(messages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})
		case ai.MessageRoleTool:
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeToolResult {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message can only have tool result content",
					})
					continue
				}

				toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message result part does not have the right type",
					})
					continue
				}

				switch toolResultPart.Output.GetType() {
				case ai.ToolResultContentTypeText:
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
				case ai.ToolResultContentTypeError:
					// TODO: check if better handling is needed
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Error, toolResultPart.ToolCallID))
				}
			}
		}
	}
	return messages, warnings
}

// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta
func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
	var annotations []openai.ChatCompletionMessageAnnotation

	// Parse the raw JSON to extract annotations
	var deltaData map[string]any
	if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
		return annotations
	}

	// Check if annotations exist in the delta
	annotationsData, ok := deltaData["annotations"].([]any)
	if !ok {
		return annotations
	}
	for _, annotationData := range annotationsData {
		annotationMap, ok := annotationData.(map[string]any)
		if !ok {
			continue
		}
		if annotationType, ok := annotationMap["type"].(string); !ok || annotationType != "url_citation" {
			continue
		}
		urlCitationData, ok := annotationMap["url_citation"].(map[string]any)
		if !ok {
			continue
		}
		// use checked assertions so a malformed payload cannot panic
		url, _ := urlCitationData["url"].(string)
		title, _ := urlCitationData["title"].(string)
		annotations = append(annotations, openai.ChatCompletionMessageAnnotation{
			Type: "url_citation",
			URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
				URL:   url,
				Title: title,
			},
		})
	}

	return annotations
}