openai.go

package openai

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"maps"
	"strings"

	"github.com/charmbracelet/ai"
	"github.com/charmbracelet/ai/internal/jsonext"
	"github.com/google/uuid"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/option"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/openai/openai-go/v2/shared"
)

type ReasoningEffort string

const (
	ReasoningEffortMinimal ReasoningEffort = "minimal"
	ReasoningEffortLow     ReasoningEffort = "low"
	ReasoningEffortMedium  ReasoningEffort = "medium"
	ReasoningEffortHigh    ReasoningEffort = "high"
)

type ProviderOptions struct {
	LogitBias           map[string]int64 `json:"logit_bias"`
	LogProbs            *bool            `json:"log_probs"`
	TopLogProbs         *int64           `json:"top_log_probs"`
	ParallelToolCalls   *bool            `json:"parallel_tool_calls"`
	User                *string          `json:"user"`
	ReasoningEffort     *ReasoningEffort `json:"reasoning_effort"`
	MaxCompletionTokens *int64           `json:"max_completion_tokens"`
	TextVerbosity       *string          `json:"text_verbosity"`
	Prediction          map[string]any   `json:"prediction"`
	Store               *bool            `json:"store"`
	Metadata            map[string]any   `json:"metadata"`
	PromptCacheKey      *string          `json:"prompt_cache_key"`
	SafetyIdentifier    *string          `json:"safety_identifier"`
	ServiceTier         *string          `json:"service_tier"`
	StructuredOutputs   *bool            `json:"structured_outputs"`
}
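
// Illustrative sketch, not part of the package API: per-call provider options
// reach this provider under the "openai" key and are decoded into
// ProviderOptions by ai.ParseOptions in prepareParams below. Keys follow the
// json tags above; the literal shape of ai.Call.ProviderOptions is assumed
// here.
//
//	call := ai.Call{
//		ProviderOptions: map[string]any{
//			"openai": map[string]any{
//				"reasoning_effort":      "low",
//				"max_completion_tokens": 2048,
//			},
//		},
//	}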

type provider struct {
	options options
}

type options struct {
	baseURL      string
	apiKey       string
	organization string
	project      string
	name         string
	headers      map[string]string
	client       option.HTTPClient
}

type Option = func(*options)

func New(opts ...Option) ai.Provider {
	options := options{
		headers: map[string]string{},
	}
	for _, o := range opts {
		o(&options)
	}

	if options.baseURL == "" {
		options.baseURL = "https://api.openai.com/v1"
	}

	if options.name == "" {
		options.name = "openai"
	}

	if options.organization != "" {
		options.headers["OpenAI-Organization"] = options.organization
	}

	if options.project != "" {
		options.headers["OpenAI-Project"] = options.project
	}

	return &provider{
		options: options,
	}
}

func WithBaseURL(baseURL string) Option {
	return func(o *options) {
		o.baseURL = baseURL
	}
}

func WithAPIKey(apiKey string) Option {
	return func(o *options) {
		o.apiKey = apiKey
	}
}

func WithOrganization(organization string) Option {
	return func(o *options) {
		o.organization = organization
	}
}

func WithProject(project string) Option {
	return func(o *options) {
		o.project = project
	}
}

func WithName(name string) Option {
	return func(o *options) {
		o.name = name
	}
}

func WithHeaders(headers map[string]string) Option {
	return func(o *options) {
		maps.Copy(o.headers, headers)
	}
}

func WithHTTPClient(client option.HTTPClient) Option {
	return func(o *options) {
		o.client = client
	}
}
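
// Illustrative usage sketch, assuming only the exported identifiers in this
// file plus the standard library (the API key source and organization value
// are stand-ins):
//
//	provider := openai.New(
//		openai.WithAPIKey(os.Getenv("OPENAI_API_KEY")),
//		openai.WithOrganization("org-example"),
//	)
//	model, err := provider.LanguageModel("gpt-4o")
//	if err != nil {
//		// handle the error
//	}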

// LanguageModel implements ai.Provider.
func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
	openaiClientOptions := []option.RequestOption{}
	if o.options.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
	}
	if o.options.baseURL != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
	}

	for key, value := range o.options.headers {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	if o.options.client != nil {
		openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
	}

	return languageModel{
		modelID:  modelID,
		provider: fmt.Sprintf("%s.chat", o.options.name),
		options:  o.options,
		client:   openai.NewClient(openaiClientOptions...),
	}, nil
}

type languageModel struct {
	provider string
	modelID  string
	client   openai.Client
	options  options
}

// Model implements ai.LanguageModel.
func (o languageModel) Model() string {
	return o.modelID
}

// Provider implements ai.LanguageModel.
func (o languageModel) Provider() string {
	return o.provider
}

func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
	params := &openai.ChatCompletionNewParams{}
	messages, warnings := toPrompt(call.Prompt)
	providerOptions := &ProviderOptions{}
	if v, ok := call.ProviderOptions["openai"]; ok {
		err := ai.ParseOptions(v, providerOptions)
		if err != nil {
			return nil, nil, err
		}
	}
	if call.TopK != nil {
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedSetting,
			Setting: "top_k",
		})
	}
	params.Messages = messages
	params.Model = o.modelID
	if providerOptions.LogitBias != nil {
		params.LogitBias = providerOptions.LogitBias
	}
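	// When both flags are set, send only top_logprobs and drop the explicit
	// logprobs flag.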
	if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
		providerOptions.LogProbs = nil
	}
	if providerOptions.LogProbs != nil {
		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
	}
	if providerOptions.TopLogProbs != nil {
		params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
	}
	if providerOptions.User != nil {
		params.User = param.NewOpt(*providerOptions.User)
	}
	if providerOptions.ParallelToolCalls != nil {
		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
	}

	if call.MaxOutputTokens != nil {
		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
	}
	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}
	if call.FrequencyPenalty != nil {
		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
	}
	if call.PresencePenalty != nil {
		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
	}

	if providerOptions.MaxCompletionTokens != nil {
		params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
	}

	if providerOptions.TextVerbosity != nil {
		params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
	}
	if providerOptions.Prediction != nil {
		// Convert map[string]any to ChatCompletionPredictionContentParam
		if content, ok := providerOptions.Prediction["content"]; ok {
			if contentStr, ok := content.(string); ok {
				params.Prediction = openai.ChatCompletionPredictionContentParam{
					Content: openai.ChatCompletionPredictionContentContentUnionParam{
						OfString: param.NewOpt(contentStr),
					},
				}
			}
		}
	}
	if providerOptions.Store != nil {
		params.Store = param.NewOpt(*providerOptions.Store)
	}
	if providerOptions.Metadata != nil {
		// Convert map[string]any to map[string]string
		metadata := make(map[string]string)
		for k, v := range providerOptions.Metadata {
			if str, ok := v.(string); ok {
				metadata[k] = str
			}
		}
		params.Metadata = metadata
	}
	if providerOptions.PromptCacheKey != nil {
		params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
	}
	if providerOptions.SafetyIdentifier != nil {
		params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
	}
	if providerOptions.ServiceTier != nil {
		params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
	}

	if providerOptions.ReasoningEffort != nil {
		switch *providerOptions.ReasoningEffort {
		case ReasoningEffortMinimal:
			params.ReasoningEffort = shared.ReasoningEffortMinimal
		case ReasoningEffortLow:
			params.ReasoningEffort = shared.ReasoningEffortLow
		case ReasoningEffortMedium:
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case ReasoningEffortHigh:
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			return nil, nil, fmt.Errorf("unsupported reasoning effort %q", *providerOptions.ReasoningEffort)
		}
	}

	if isReasoningModel(o.modelID) {
		// remove unsupported settings for reasoning models
		// see https://platform.openai.com/docs/guides/reasoning#limitations
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for reasoning models",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopP",
				Details: "TopP is not supported for reasoning models",
			})
		}
		if call.FrequencyPenalty != nil {
			params.FrequencyPenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "FrequencyPenalty",
				Details: "FrequencyPenalty is not supported for reasoning models",
			})
		}
		if call.PresencePenalty != nil {
			params.PresencePenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "PresencePenalty",
				Details: "PresencePenalty is not supported for reasoning models",
			})
		}
		if providerOptions.LogitBias != nil {
			params.LogitBias = nil
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "LogitBias",
				Message: "LogitBias is not supported for reasoning models",
			})
		}
		if providerOptions.LogProbs != nil {
			params.Logprobs = param.Opt[bool]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "Logprobs",
				Message: "Logprobs is not supported for reasoning models",
			})
		}
		if providerOptions.TopLogProbs != nil {
			params.TopLogprobs = param.Opt[int64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopLogprobs",
				Message: "TopLogprobs is not supported for reasoning models",
			})
		}

		// reasoning models use max_completion_tokens instead of max_tokens
		if call.MaxOutputTokens != nil {
			if providerOptions.MaxCompletionTokens == nil {
				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
			}
			params.MaxTokens = param.Opt[int64]{}
		}
	}

	// Handle search preview models
	if isSearchPreviewModel(o.modelID) {
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for the search preview models and has been removed.",
			})
		}
	}

	// Handle service tier validation
	if providerOptions.ServiceTier != nil {
		serviceTier := *providerOptions.ServiceTier
		if serviceTier == "flex" && !supportsFlexProcessing(o.modelID) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "ServiceTier",
				Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
			})
		} else if serviceTier == "priority" && !supportsPriorityProcessing(o.modelID) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "ServiceTier",
				Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported",
			})
		}
	}

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}
	return params, warnings, nil
}

func (o languageModel) handleError(err error) error {
	var apiErr *openai.Error
	if errors.As(err, &apiErr) {
		requestDump := apiErr.DumpRequest(true)
		responseDump := apiErr.DumpResponse(true)
		headers := map[string]string{}
		for k, h := range apiErr.Response.Header {
			v := h[len(h)-1]
			headers[strings.ToLower(k)] = v
		}
		return ai.NewAPICallError(
			apiErr.Message,
			apiErr.Request.URL.String(),
			string(requestDump),
			apiErr.StatusCode,
			headers,
			string(responseDump),
			apiErr,
			false,
		)
	}
	return err
}

// Generate implements ai.LanguageModel.
func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, o.handleError(err)
	}

	if len(response.Choices) == 0 {
		return nil, errors.New("no response generated")
	}
	choice := response.Choices[0]
	content := make([]ai.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
	text := choice.Message.Content
	if text != "" {
		content = append(content, ai.TextContent{
			Text: text,
		})
	}

	for _, tc := range choice.Message.ToolCalls {
		toolCallID := tc.ID
		if toolCallID == "" {
			toolCallID = uuid.NewString()
		}
		content = append(content, ai.ToolCallContent{
			ProviderExecuted: false, // TODO: update when handling other tools
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Handle annotations/citations
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, ai.SourceContent{
				SourceType: ai.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	completionTokenDetails := response.Usage.CompletionTokensDetails
	promptTokenDetails := response.Usage.PromptTokensDetails

	// Build provider metadata
	providerMetadata := ai.ProviderMetadata{
		"openai": make(map[string]any),
	}

	// Add logprobs if available
	if len(choice.Logprobs.Content) > 0 {
		providerMetadata["openai"]["logprobs"] = choice.Logprobs.Content
	}

	// Add prediction tokens if available
	if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
		if completionTokenDetails.AcceptedPredictionTokens > 0 {
			providerMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
		}
		if completionTokenDetails.RejectedPredictionTokens > 0 {
			providerMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
		}
	}

	return &ai.Response{
		Content: content,
		Usage: ai.Usage{
			InputTokens:     response.Usage.PromptTokens,
			OutputTokens:    response.Usage.CompletionTokens,
			TotalTokens:     response.Usage.TotalTokens,
			ReasoningTokens: completionTokenDetails.ReasoningTokens,
			CacheReadTokens: promptTokenDetails.CachedTokens,
		},
		FinishReason:     mapOpenAiFinishReason(choice.FinishReason),
		ProviderMetadata: providerMetadata,
		Warnings:         warnings,
	}, nil
}

type toolCall struct {
	id          string
	name        string
	arguments   string
	hasFinished bool
}

// Stream implements ai.LanguageModel.
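// The returned ai.StreamResponse is a push-style iterator: it drives the
// OpenAI SSE stream and invokes yield once per ai.StreamPart, stopping early
// when yield returns false.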
func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}

	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
	isActiveText := false
	toolCalls := make(map[int64]toolCall)

	// Build provider metadata for streaming
	streamProviderMetadata := ai.ProviderMetadata{
		"openai": make(map[string]any),
	}

	acc := openai.ChatCompletionAccumulator{}
	var usage ai.Usage
	return func(yield func(ai.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(ai.StreamPart{
				Type:     ai.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}
		for stream.Next() {
			chunk := stream.Current()
			acc.AddChunk(chunk)
			if chunk.Usage.TotalTokens > 0 {
				// Compute usage here because the accumulator does not
				// aggregate prompt token details.
				completionTokenDetails := chunk.Usage.CompletionTokensDetails
				promptTokenDetails := chunk.Usage.PromptTokensDetails
				usage = ai.Usage{
					InputTokens:     chunk.Usage.PromptTokens,
					OutputTokens:    chunk.Usage.CompletionTokens,
					TotalTokens:     chunk.Usage.TotalTokens,
					ReasoningTokens: completionTokenDetails.ReasoningTokens,
					CacheReadTokens: promptTokenDetails.CachedTokens,
				}

				// Add prediction tokens if available
				if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
					if completionTokenDetails.AcceptedPredictionTokens > 0 {
						streamProviderMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
					}
					if completionTokenDetails.RejectedPredictionTokens > 0 {
						streamProviderMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
					}
				}
			}
			if len(chunk.Choices) == 0 {
				continue
			}
			for _, choice := range chunk.Choices {
				switch {
				case choice.Delta.Content != "":
					if !isActiveText {
						isActiveText = true
						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeTextStart,
							ID:   "0",
						}) {
							return
						}
					}
					if !yield(ai.StreamPart{
						Type:  ai.StreamPartTypeTextDelta,
						ID:    "0",
						Delta: choice.Delta.Content,
					}) {
						return
					}
				case len(choice.Delta.ToolCalls) > 0:
					if isActiveText {
						isActiveText = false
						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeTextEnd,
							ID:   "0",
						}) {
							return
						}
					}

					for _, toolCallDelta := range choice.Delta.ToolCalls {
						if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
							if existingToolCall.hasFinished {
								continue
							}
							if toolCallDelta.Function.Arguments != "" {
								existingToolCall.arguments += toolCallDelta.Function.Arguments
							}
							if !yield(ai.StreamPart{
								Type:  ai.StreamPartTypeToolInputDelta,
								ID:    existingToolCall.id,
								Delta: toolCallDelta.Function.Arguments,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = existingToolCall
							if jsonext.IsValidJSON(existingToolCall.arguments) {
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeToolInputEnd,
									ID:   existingToolCall.id,
								}) {
									return
								}

								if !yield(ai.StreamPart{
									Type:          ai.StreamPartTypeToolCall,
									ID:            existingToolCall.id,
									ToolCallName:  existingToolCall.name,
									ToolCallInput: existingToolCall.arguments,
								}) {
									return
								}
								existingToolCall.hasFinished = true
								toolCalls[toolCallDelta.Index] = existingToolCall
							}
						} else {
							// New tool call: validate the delta before tracking it.
							var err error
							if toolCallDelta.Type != "function" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
							}
							if toolCallDelta.ID == "" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
							}
							if toolCallDelta.Function.Name == "" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
							}
							if err != nil {
								yield(ai.StreamPart{
									Type:  ai.StreamPartTypeError,
									Error: err,
								})
								return
							}

							if !yield(ai.StreamPart{
								Type:         ai.StreamPartTypeToolInputStart,
								ID:           toolCallDelta.ID,
								ToolCallName: toolCallDelta.Function.Name,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = toolCall{
								id:        toolCallDelta.ID,
								name:      toolCallDelta.Function.Name,
								arguments: toolCallDelta.Function.Arguments,
							}

							exTc := toolCalls[toolCallDelta.Index]
							if exTc.arguments != "" {
								if !yield(ai.StreamPart{
									Type:  ai.StreamPartTypeToolInputDelta,
									ID:    exTc.id,
									Delta: exTc.arguments,
								}) {
									return
								}
								if jsonext.IsValidJSON(toolCalls[toolCallDelta.Index].arguments) {
									if !yield(ai.StreamPart{
										Type: ai.StreamPartTypeToolInputEnd,
										ID:   toolCallDelta.ID,
									}) {
										return
									}

									if !yield(ai.StreamPart{
										Type:          ai.StreamPartTypeToolCall,
										ID:            exTc.id,
										ToolCallName:  exTc.name,
										ToolCallInput: exTc.arguments,
									}) {
										return
									}
									exTc.hasFinished = true
									toolCalls[toolCallDelta.Index] = exTc
								}
							}
							continue
						}
					}
				}
			}

			// Check for annotations in the delta's raw JSON
			for _, choice := range chunk.Choices {
				if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
					for _, annotation := range annotations {
						if annotation.Type == "url_citation" {
							if !yield(ai.StreamPart{
								Type:       ai.StreamPartTypeSource,
								ID:         uuid.NewString(),
								SourceType: ai.SourceTypeURL,
								URL:        annotation.URLCitation.URL,
								Title:      annotation.URLCitation.Title,
							}) {
								return
							}
						}
					}
				}
			}
		}
		err := stream.Err()
		if err == nil || errors.Is(err, io.EOF) {
			// finished
			if isActiveText {
				isActiveText = false
				if !yield(ai.StreamPart{
					Type: ai.StreamPartTypeTextEnd,
					ID:   "0",
				}) {
					return
				}
			}

			// Add logprobs if available
			if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
				streamProviderMetadata["openai"]["logprobs"] = acc.Choices[0].Logprobs.Content
			}

			// Handle annotations/citations from accumulated response
			if len(acc.Choices) > 0 {
				for _, annotation := range acc.Choices[0].Message.Annotations {
					if annotation.Type == "url_citation" {
						if !yield(ai.StreamPart{
							Type:       ai.StreamPartTypeSource,
							ID:         acc.ID,
							SourceType: ai.SourceTypeURL,
							URL:        annotation.URLCitation.URL,
							Title:      annotation.URLCitation.Title,
						}) {
							return
						}
					}
				}
			}

			finishReason := ai.FinishReasonUnknown
			if len(acc.Choices) > 0 {
				finishReason = mapOpenAiFinishReason(acc.Choices[0].FinishReason)
			}
			yield(ai.StreamPart{
				Type:             ai.StreamPartTypeFinish,
				Usage:            usage,
				FinishReason:     finishReason,
				ProviderMetadata: streamProviderMetadata,
			})
			return
		} else {
			yield(ai.StreamPart{
				Type:  ai.StreamPartTypeError,
				Error: o.handleError(err),
			})
			return
		}
	}, nil
}

func mapOpenAiFinishReason(finishReason string) ai.FinishReason {
	switch finishReason {
	case "stop":
		return ai.FinishReasonStop
	case "length":
		return ai.FinishReasonLength
	case "content_filter":
		return ai.FinishReasonContentFilter
	case "function_call", "tool_calls":
		return ai.FinishReasonToolCalls
	default:
		return ai.FinishReasonUnknown
	}
}

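// isReasoningModel reports whether modelID names a reasoning model. For
// example, "o4-mini" and "gpt-5" match, while "gpt-4o" and
// "gpt-5-chat-latest" do not.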
func isReasoningModel(modelID string) bool {
	// gpt-5-chat models are conversational, not reasoning, models.
	return (strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5")) &&
		!strings.HasPrefix(modelID, "gpt-5-chat")
}

func isSearchPreviewModel(modelID string) bool {
	return strings.Contains(modelID, "search-preview")
}

func supportsFlexProcessing(modelID string) bool {
	return strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini") || strings.HasPrefix(modelID, "gpt-5")
}

func supportsPriorityProcessing(modelID string) bool {
	// gpt-5-nano is excluded, matching the service tier warning in prepareParams.
	return strings.HasPrefix(modelID, "gpt-4") || strings.HasPrefix(modelID, "gpt-5-mini") ||
		(strings.HasPrefix(modelID, "gpt-5") && !strings.HasPrefix(modelID, "gpt-5-nano")) ||
		strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini")
}

func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				continue
			}
			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
				OfFunction: &openai.ChatCompletionFunctionToolParam{
					Function: shared.FunctionDefinitionParam{
						Name:        ft.Name,
						Description: param.NewOpt(ft.Description),
						Parameters:  openai.FunctionParameters(ft.InputSchema),
						Strict:      param.NewOpt(false),
					},
					Type: "function",
				},
			})
			continue
		}

		// TODO: handle provider tool calls
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return openAiTools, openAiToolChoice, warnings
	}

	switch *toolChoice {
	case ai.ToolChoiceAuto:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("auto"),
		}
	case ai.ToolChoiceNone:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("none"),
		}
	default:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
				Type: "function",
				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
					Name: string(*toolChoice),
				},
			},
		}
	}
	return openAiTools, openAiToolChoice, warnings
}

func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
	var messages []openai.ChatCompletionMessageParamUnion
	var warnings []ai.CallWarning
	for _, msg := range prompt {
		switch msg.Role {
		case ai.MessageRoleSystem:
			var systemPromptParts []string
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeText {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt can only have text content",
					})
					continue
				}
				textPart, ok := ai.AsContentType[ai.TextPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt text part does not have the right type",
					})
					continue
				}
				text := textPart.Text
				if strings.TrimSpace(text) != "" {
					systemPromptParts = append(systemPromptParts, textPart.Text)
				}
			}
			if len(systemPromptParts) == 0 {
				warnings = append(warnings, ai.CallWarning{
					Type:    ai.CallWarningTypeOther,
					Message: "system prompt has no text parts",
				})
				continue
			}
			messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
		case ai.MessageRoleUser:
			// Simple user message with only text content.
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "user message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.UserMessage(textPart.Text))
				continue
			}
			// Text content plus attachments. For now only image, audio, and
			// PDF parts are handled.
			// TODO: add the supported media types to the language model so we
			// can use that to validate the data here.
			var content []openai.ChatCompletionContentPartUnionParam
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message text part does not have the right type",
						})
						continue
					}
					content = append(content, openai.ChatCompletionContentPartUnionParam{
						OfText: &openai.ChatCompletionContentPartTextParam{
							Text: textPart.Text,
						},
					})
				case ai.ContentTypeFile:
					filePart, ok := ai.AsContentType[ai.FilePart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message file part does not have the right type",
						})
						continue
					}

					switch {
					case strings.HasPrefix(filePart.MediaType, "image/"):
						// Handle image files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
						imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}

						// Check for provider-specific options like image detail
						if providerOptions, ok := filePart.ProviderOptions["openai"]; ok {
							if detail, ok := providerOptions["imageDetail"].(string); ok {
								imageURL.Detail = detail
							}
						}

						imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})

					case filePart.MediaType == "audio/wav":
						// Handle WAV audio files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "wav",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
						// Handle MP3 audio files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "mp3",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "application/pdf":
						// Handle PDF files
						dataStr := string(filePart.Data)

						// Check if data looks like a file ID (starts with "file-")
						if strings.HasPrefix(dataStr, "file-") {
							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									FileID: param.NewOpt(dataStr),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						} else {
							// Handle as base64 data
							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
							data := "data:application/pdf;base64," + base64Encoded

							filename := filePart.Filename
							if filename == "" {
								// Generate default filename based on content index
								filename = fmt.Sprintf("part-%d.pdf", len(content))
							}

							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									Filename: param.NewOpt(filename),
									FileData: param.NewOpt(data),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						}

					default:
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
						})
					}
				}
			}
			messages = append(messages, openai.UserMessage(content))
		case ai.MessageRoleAssistant:
			// Simple assistant message with only text content.
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "assistant message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.AssistantMessage(textPart.Text))
				continue
			}
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message text part does not have the right type",
						})
						continue
					}
					assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
						OfString: param.NewOpt(textPart.Text),
					}
				case ai.ContentTypeToolCall:
					toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message tool part does not have the right type",
						})
						continue
					}
					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
						openai.ChatCompletionMessageToolCallUnionParam{
							OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
								ID:   toolCallPart.ToolCallID,
								Type: "function",
								Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
									Name:      toolCallPart.ToolName,
									Arguments: toolCallPart.Input,
								},
							},
						})
				}
			}
			messages = append(messages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})
		case ai.MessageRoleTool:
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeToolResult {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message can only have tool result content",
					})
					continue
				}

				toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message result part does not have the right type",
					})
					continue
				}

				switch toolResultPart.Output.GetType() {
				case ai.ToolResultContentTypeText:
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
				case ai.ToolResultContentTypeError:
					// TODO: check if better handling is needed
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
				}
			}
		}
	}
	return messages, warnings
}

// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
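// The assumed shape, abbreviated to the fields read below, is:
//
//	{"annotations": [{"type": "url_citation",
//		"url_citation": {"url": "https://…", "title": "…"}}]}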
func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
	var annotations []openai.ChatCompletionMessageAnnotation

	// Parse the raw JSON to extract annotations
	var deltaData map[string]any
	if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
		return annotations
	}

	// Check if annotations exist in the delta
	if annotationsData, ok := deltaData["annotations"].([]any); ok {
		for _, annotationData := range annotationsData {
			if annotationMap, ok := annotationData.(map[string]any); ok {
				if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
					if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
						// Use checked assertions so a malformed annotation
						// cannot panic the stream.
						url, _ := urlCitationData["url"].(string)
						title, _ := urlCitationData["title"].(string)
						annotation := openai.ChatCompletionMessageAnnotation{
							Type: "url_citation",
							URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
								URL:   url,
								Title: title,
							},
						}
						annotations = append(annotations, annotation)
					}
				}
			}
		}
	}

	return annotations
}