openai.go

package openai

import (
	"cmp"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"maps"
	"strings"

	"github.com/charmbracelet/ai/ai"
	xjson "github.com/charmbracelet/x/json"
	"github.com/google/uuid"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/option"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/openai/openai-go/v2/shared"
)

const (
	ProviderName = "openai"
	DefaultURL   = "https://api.openai.com/v1"
)

type provider struct {
	options options
}

type options struct {
	baseURL      string
	apiKey       string
	organization string
	project      string
	name         string
	headers      map[string]string
	client       option.HTTPClient
}

type Option = func(*options)

func New(opts ...Option) ai.Provider {
	providerOptions := options{
		headers: map[string]string{},
	}
	for _, o := range opts {
		o(&providerOptions)
	}

	providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
	providerOptions.name = cmp.Or(providerOptions.name, ProviderName)

	if providerOptions.organization != "" {
		providerOptions.headers["OpenAI-Organization"] = providerOptions.organization
	}
	if providerOptions.project != "" {
		providerOptions.headers["OpenAI-Project"] = providerOptions.project
	}

	return &provider{options: providerOptions}
}

func WithBaseURL(baseURL string) Option {
	return func(o *options) {
		o.baseURL = baseURL
	}
}

func WithAPIKey(apiKey string) Option {
	return func(o *options) {
		o.apiKey = apiKey
	}
}

func WithOrganization(organization string) Option {
	return func(o *options) {
		o.organization = organization
	}
}

func WithProject(project string) Option {
	return func(o *options) {
		o.project = project
	}
}

func WithName(name string) Option {
	return func(o *options) {
		o.name = name
	}
}

func WithHeaders(headers map[string]string) Option {
	return func(o *options) {
		maps.Copy(o.headers, headers)
	}
}

func WithHTTPClient(client option.HTTPClient) Option {
	return func(o *options) {
		o.client = client
	}
}
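
// A minimal construction sketch: build a provider, then resolve a model.
// The model ID and environment variable shown here are illustrative
// assumptions, not part of this package.
//
//	p := New(
//		WithAPIKey(os.Getenv("OPENAI_API_KEY")), // hypothetical env var
//		WithOrganization("org-123"),             // hypothetical org ID
//	)
//	model, err := p.LanguageModel("gpt-4o") // hypothetical model ID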

// LanguageModel implements ai.Provider.
func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
	openaiClientOptions := []option.RequestOption{}
	if o.options.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
	}
	if o.options.baseURL != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
	}

	for key, value := range o.options.headers {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	if o.options.client != nil {
		openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
	}

	return languageModel{
		modelID:  modelID,
		provider: fmt.Sprintf("%s.chat", o.options.name),
		options:  o.options,
		client:   openai.NewClient(openaiClientOptions...),
	}, nil
}

type languageModel struct {
	provider string
	modelID  string
	client   openai.Client
	options  options
}

// Model implements ai.LanguageModel.
func (o languageModel) Model() string {
	return o.modelID
}

// Provider implements ai.LanguageModel.
func (o languageModel) Provider() string {
	return o.provider
}

func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
	params := &openai.ChatCompletionNewParams{}
	messages, warnings := toPrompt(call.Prompt)
	providerOptions := &ProviderOptions{}
	if v, ok := call.ProviderOptions[ProviderOptionsKey]; ok {
		providerOptions, ok = v.(*ProviderOptions)
		if !ok {
			return nil, nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
		}
	}
	if call.TopK != nil {
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedSetting,
			Setting: "top_k",
		})
	}
	params.Messages = messages
	params.Model = o.modelID
	if providerOptions.LogitBias != nil {
		params.LogitBias = providerOptions.LogitBias
	}
	// When both are set, top_logprobs takes precedence and the plain
	// logprobs flag is dropped.
	if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
		providerOptions.LogProbs = nil
	}
	if providerOptions.LogProbs != nil {
		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
	}
	if providerOptions.TopLogProbs != nil {
		params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
	}
	if providerOptions.User != nil {
		params.User = param.NewOpt(*providerOptions.User)
	}
	if providerOptions.ParallelToolCalls != nil {
		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
	}

	if call.MaxOutputTokens != nil {
		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
	}
	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}
	if call.FrequencyPenalty != nil {
		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
	}
	if call.PresencePenalty != nil {
		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
	}

	if providerOptions.MaxCompletionTokens != nil {
		params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
	}

	if providerOptions.TextVerbosity != nil {
		params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
	}
	if providerOptions.Prediction != nil {
		// Convert map[string]any to ChatCompletionPredictionContentParam.
		if content, ok := providerOptions.Prediction["content"]; ok {
			if contentStr, ok := content.(string); ok {
				params.Prediction = openai.ChatCompletionPredictionContentParam{
					Content: openai.ChatCompletionPredictionContentContentUnionParam{
						OfString: param.NewOpt(contentStr),
					},
				}
			}
		}
	}
	if providerOptions.Store != nil {
		params.Store = param.NewOpt(*providerOptions.Store)
	}
	if providerOptions.Metadata != nil {
		// Convert map[string]any to map[string]string, keeping string values only.
		metadata := make(map[string]string)
		for k, v := range providerOptions.Metadata {
			if str, ok := v.(string); ok {
				metadata[k] = str
			}
		}
		params.Metadata = metadata
	}
	if providerOptions.PromptCacheKey != nil {
		params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
	}
	if providerOptions.SafetyIdentifier != nil {
		params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
	}
	if providerOptions.ServiceTier != nil {
		params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
	}

	if providerOptions.ReasoningEffort != nil {
		switch *providerOptions.ReasoningEffort {
		case ReasoningEffortMinimal:
			params.ReasoningEffort = shared.ReasoningEffortMinimal
		case ReasoningEffortLow:
			params.ReasoningEffort = shared.ReasoningEffortLow
		case ReasoningEffortMedium:
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case ReasoningEffortHigh:
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			return nil, nil, fmt.Errorf("reasoning effort `%s` not supported", *providerOptions.ReasoningEffort)
		}
	}

	if isReasoningModel(o.modelID) {
		// Remove unsupported settings for reasoning models.
		// See https://platform.openai.com/docs/guides/reasoning#limitations
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for reasoning models",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "top_p",
				Details: "top_p is not supported for reasoning models",
			})
		}
		if call.FrequencyPenalty != nil {
			params.FrequencyPenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "frequency_penalty",
				Details: "frequency_penalty is not supported for reasoning models",
			})
		}
		if call.PresencePenalty != nil {
			params.PresencePenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "presence_penalty",
				Details: "presence_penalty is not supported for reasoning models",
			})
		}
		if providerOptions.LogitBias != nil {
			params.LogitBias = nil
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "logit_bias",
				Details: "logit_bias is not supported for reasoning models",
			})
		}
		if providerOptions.LogProbs != nil {
			params.Logprobs = param.Opt[bool]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "logprobs",
				Details: "logprobs is not supported for reasoning models",
			})
		}
		if providerOptions.TopLogProbs != nil {
			params.TopLogprobs = param.Opt[int64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "top_logprobs",
				Details: "top_logprobs is not supported for reasoning models",
			})
		}

		// Reasoning models use max_completion_tokens instead of max_tokens.
		if call.MaxOutputTokens != nil {
			if providerOptions.MaxCompletionTokens == nil {
				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
			}
			params.MaxTokens = param.Opt[int64]{}
		}
	}

	// Handle search preview models.
	if isSearchPreviewModel(o.modelID) {
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for the search preview models and has been removed.",
			})
		}
	}

	// Handle service tier validation.
	if providerOptions.ServiceTier != nil {
		serviceTier := *providerOptions.ServiceTier
		if serviceTier == "flex" && !supportsFlexProcessing(o.modelID) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "service_tier",
				Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
			})
		} else if serviceTier == "priority" && !supportsPriorityProcessing(o.modelID) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "service_tier",
				Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access; gpt-5-nano is not supported",
			})
		}
	}

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}
	return params, warnings, nil
}
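
// A sketch of supplying provider-specific options on a call, as read at the
// top of prepareParams. The exact shape of ai.Call and ai.ProviderOptions is
// an assumption here; only ProviderOptionsKey and *ProviderOptions come from
// this file.
//
//	effort := ReasoningEffortHigh
//	call := ai.Call{
//		Prompt: prompt, // hypothetical prompt built elsewhere
//		ProviderOptions: ai.ProviderOptions{ // assumed map type keyed by provider
//			ProviderOptionsKey: &ProviderOptions{ReasoningEffort: &effort},
//		},
//	}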

func (o languageModel) handleError(err error) error {
	var apiErr *openai.Error
	if errors.As(err, &apiErr) {
		requestDump := apiErr.DumpRequest(true)
		responseDump := apiErr.DumpResponse(true)
		headers := map[string]string{}
		for k, h := range apiErr.Response.Header {
			if len(h) == 0 {
				continue
			}
			// Keep the last value set for each header.
			headers[strings.ToLower(k)] = h[len(h)-1]
		}
		return ai.NewAPICallError(
			apiErr.Message,
			apiErr.Request.URL.String(),
			string(requestDump),
			apiErr.StatusCode,
			headers,
			string(responseDump),
			apiErr,
			false,
		)
	}
	return err
}

// Generate implements ai.LanguageModel.
func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, o.handleError(err)
	}

	if len(response.Choices) == 0 {
		return nil, errors.New("no response generated")
	}
	choice := response.Choices[0]
	content := make([]ai.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
	text := choice.Message.Content
	if text != "" {
		content = append(content, ai.TextContent{
			Text: text,
		})
	}

	for _, tc := range choice.Message.ToolCalls {
		toolCallID := tc.ID
		if toolCallID == "" {
			toolCallID = uuid.NewString()
		}
		content = append(content, ai.ToolCallContent{
			ProviderExecuted: false, // TODO: update when handling other tools
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Handle annotations/citations.
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, ai.SourceContent{
				SourceType: ai.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	completionTokenDetails := response.Usage.CompletionTokensDetails
	promptTokenDetails := response.Usage.PromptTokensDetails

	// Build provider metadata.
	providerMetadata := &ProviderMetadata{}
	// Add logprobs if available.
	if len(choice.Logprobs.Content) > 0 {
		providerMetadata.Logprobs = choice.Logprobs.Content
	}

	// Add prediction token counts if available.
	if completionTokenDetails.AcceptedPredictionTokens > 0 {
		providerMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
	}
	if completionTokenDetails.RejectedPredictionTokens > 0 {
		providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
	}

	return &ai.Response{
		Content: content,
		Usage: ai.Usage{
			InputTokens:     response.Usage.PromptTokens,
			OutputTokens:    response.Usage.CompletionTokens,
			TotalTokens:     response.Usage.TotalTokens,
			ReasoningTokens: completionTokenDetails.ReasoningTokens,
			CacheReadTokens: promptTokenDetails.CachedTokens,
		},
		FinishReason: mapOpenAiFinishReason(choice.FinishReason),
		ProviderMetadata: ai.ProviderMetadata{
			ProviderOptionsKey: providerMetadata,
		},
		Warnings: warnings,
	}, nil
}
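
// A minimal generation sketch (context handling and prompt construction are
// assumptions; error handling elided):
//
//	resp, err := model.Generate(ctx, ai.Call{Prompt: prompt})
//	if err != nil {
//		// handle err
//	}
//	for _, part := range resp.Content {
//		// e.g. inspect ai.TextContent or ai.ToolCallContent values
//	}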

type toolCall struct {
	id          string
	name        string
	arguments   string
	hasFinished bool
}

// Stream implements ai.LanguageModel.
func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}

	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
	isActiveText := false
	toolCalls := make(map[int64]toolCall)

	// Build provider metadata for streaming.
	streamProviderMetadata := &ProviderMetadata{}
	acc := openai.ChatCompletionAccumulator{}
	var usage ai.Usage
	return func(yield func(ai.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(ai.StreamPart{
				Type:     ai.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}
		for stream.Next() {
			chunk := stream.Current()
			acc.AddChunk(chunk)
			if chunk.Usage.TotalTokens > 0 {
				// Collected here because the accumulator does not add
				// prompt token details.
				completionTokenDetails := chunk.Usage.CompletionTokensDetails
				promptTokenDetails := chunk.Usage.PromptTokensDetails
				usage = ai.Usage{
					InputTokens:     chunk.Usage.PromptTokens,
					OutputTokens:    chunk.Usage.CompletionTokens,
					TotalTokens:     chunk.Usage.TotalTokens,
					ReasoningTokens: completionTokenDetails.ReasoningTokens,
					CacheReadTokens: promptTokenDetails.CachedTokens,
				}

				// Add prediction token counts if available.
				if completionTokenDetails.AcceptedPredictionTokens > 0 {
					streamProviderMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
				}
				if completionTokenDetails.RejectedPredictionTokens > 0 {
					streamProviderMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
				}
			}
			if len(chunk.Choices) == 0 {
				continue
			}
			for _, choice := range chunk.Choices {
				switch {
				case choice.Delta.Content != "":
					if !isActiveText {
						isActiveText = true
						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeTextStart,
							ID:   "0",
						}) {
							return
						}
					}
					if !yield(ai.StreamPart{
						Type:  ai.StreamPartTypeTextDelta,
						ID:    "0",
						Delta: choice.Delta.Content,
					}) {
						return
					}
				case len(choice.Delta.ToolCalls) > 0:
					if isActiveText {
						isActiveText = false
						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeTextEnd,
							ID:   "0",
						}) {
							return
						}
					}

					for _, toolCallDelta := range choice.Delta.ToolCalls {
						if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
							if existingToolCall.hasFinished {
								continue
							}
							if toolCallDelta.Function.Arguments != "" {
								existingToolCall.arguments += toolCallDelta.Function.Arguments
							}
							if !yield(ai.StreamPart{
								Type:  ai.StreamPartTypeToolInputDelta,
								ID:    existingToolCall.id,
								Delta: toolCallDelta.Function.Arguments,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = existingToolCall
							if xjson.IsValid(existingToolCall.arguments) {
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeToolInputEnd,
									ID:   existingToolCall.id,
								}) {
									return
								}

								if !yield(ai.StreamPart{
									Type:          ai.StreamPartTypeToolCall,
									ID:            existingToolCall.id,
									ToolCallName:  existingToolCall.name,
									ToolCallInput: existingToolCall.arguments,
								}) {
									return
								}
								existingToolCall.hasFinished = true
								toolCalls[toolCallDelta.Index] = existingToolCall
							}
						} else {
							// New tool call: validate the first delta.
							var err error
							if toolCallDelta.Type != "function" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
							}
							if toolCallDelta.ID == "" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
							}
							if toolCallDelta.Function.Name == "" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
							}
							if err != nil {
								yield(ai.StreamPart{
									Type:  ai.StreamPartTypeError,
									Error: o.handleError(err),
								})
								return
							}

							if !yield(ai.StreamPart{
								Type:         ai.StreamPartTypeToolInputStart,
								ID:           toolCallDelta.ID,
								ToolCallName: toolCallDelta.Function.Name,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = toolCall{
								id:        toolCallDelta.ID,
								name:      toolCallDelta.Function.Name,
								arguments: toolCallDelta.Function.Arguments,
							}

							// Some models send complete arguments in the
							// first delta; emit them right away.
							exTc := toolCalls[toolCallDelta.Index]
							if exTc.arguments != "" {
								if !yield(ai.StreamPart{
									Type:  ai.StreamPartTypeToolInputDelta,
									ID:    exTc.id,
									Delta: exTc.arguments,
								}) {
									return
								}
								if xjson.IsValid(exTc.arguments) {
									if !yield(ai.StreamPart{
										Type: ai.StreamPartTypeToolInputEnd,
										ID:   toolCallDelta.ID,
									}) {
										return
									}

									if !yield(ai.StreamPart{
										Type:          ai.StreamPartTypeToolCall,
										ID:            exTc.id,
										ToolCallName:  exTc.name,
										ToolCallInput: exTc.arguments,
									}) {
										return
									}
									exTc.hasFinished = true
									toolCalls[toolCallDelta.Index] = exTc
								}
							}
						}
					}
				}
			}

			// Check for annotations in the delta's raw JSON.
			for _, choice := range chunk.Choices {
				if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
					for _, annotation := range annotations {
						if annotation.Type == "url_citation" {
							if !yield(ai.StreamPart{
								Type:       ai.StreamPartTypeSource,
								ID:         uuid.NewString(),
								SourceType: ai.SourceTypeURL,
								URL:        annotation.URLCitation.URL,
								Title:      annotation.URLCitation.Title,
							}) {
								return
							}
						}
					}
				}
			}
		}
		err := stream.Err()
		if err != nil && !errors.Is(err, io.EOF) {
			yield(ai.StreamPart{
				Type:  ai.StreamPartTypeError,
				Error: o.handleError(err),
			})
			return
		}

		// Stream finished normally.
		if isActiveText {
			isActiveText = false
			if !yield(ai.StreamPart{
				Type: ai.StreamPartTypeTextEnd,
				ID:   "0",
			}) {
				return
			}
		}

		// Add logprobs if available.
		if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
			streamProviderMetadata.Logprobs = acc.Choices[0].Logprobs.Content
		}

		// Handle annotations/citations from the accumulated response.
		if len(acc.Choices) > 0 {
			for _, annotation := range acc.Choices[0].Message.Annotations {
				if annotation.Type == "url_citation" {
					if !yield(ai.StreamPart{
						Type:       ai.StreamPartTypeSource,
						ID:         acc.ID,
						SourceType: ai.SourceTypeURL,
						URL:        annotation.URLCitation.URL,
						Title:      annotation.URLCitation.Title,
					}) {
						return
					}
				}
			}
		}

		// Guard against an empty accumulator before reading the finish reason.
		finishReason := ai.FinishReasonUnknown
		if len(acc.Choices) > 0 {
			finishReason = mapOpenAiFinishReason(acc.Choices[0].FinishReason)
		}
		yield(ai.StreamPart{
			Type:         ai.StreamPartTypeFinish,
			Usage:        usage,
			FinishReason: finishReason,
			ProviderMetadata: ai.ProviderMetadata{
				ProviderOptionsKey: streamProviderMetadata,
			},
		})
	}, nil
}
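
// A consumption sketch for the stream above, assuming ai.StreamResponse is a
// range-over-func iterator of ai.StreamPart (as the returned closure's
// signature suggests):
//
//	parts, err := model.Stream(ctx, call)
//	if err != nil {
//		// handle err
//	}
//	for part := range parts {
//		switch part.Type {
//		case ai.StreamPartTypeTextDelta:
//			fmt.Print(part.Delta)
//		case ai.StreamPartTypeFinish:
//			// part.Usage and part.FinishReason are populated here
//		}
//	}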

func mapOpenAiFinishReason(finishReason string) ai.FinishReason {
	switch finishReason {
	case "stop":
		return ai.FinishReasonStop
	case "length":
		return ai.FinishReasonLength
	case "content_filter":
		return ai.FinishReasonContentFilter
	case "function_call", "tool_calls":
		return ai.FinishReasonToolCalls
	default:
		return ai.FinishReasonUnknown
	}
}

func isReasoningModel(modelID string) bool {
	// gpt-5-chat models are conversational and expose no reasoning controls,
	// so they are excluded from the gpt-5 family here.
	return strings.HasPrefix(modelID, "o") ||
		(strings.HasPrefix(modelID, "gpt-5") && !strings.HasPrefix(modelID, "gpt-5-chat"))
}

func isSearchPreviewModel(modelID string) bool {
	return strings.Contains(modelID, "search-preview")
}

func supportsFlexProcessing(modelID string) bool {
	return strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini") || strings.HasPrefix(modelID, "gpt-5")
}

func supportsPriorityProcessing(modelID string) bool {
	// gpt-5-nano is excluded, matching the warning emitted in prepareParams.
	return strings.HasPrefix(modelID, "gpt-4") || strings.HasPrefix(modelID, "o3") ||
		strings.HasPrefix(modelID, "o4-mini") ||
		(strings.HasPrefix(modelID, "gpt-5") && !strings.HasPrefix(modelID, "gpt-5-nano"))
}
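
// Illustrative classification results under the helpers above (model IDs
// are assumptions, shown for orientation only):
//
//	isReasoningModel("o4-mini")              // true
//	isReasoningModel("gpt-4o")               // false
//	supportsFlexProcessing("o3-mini")        // true
//	supportsPriorityProcessing("gpt-5-nano") // false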

func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				continue
			}
			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
				OfFunction: &openai.ChatCompletionFunctionToolParam{
					Function: shared.FunctionDefinitionParam{
						Name:        ft.Name,
						Description: param.NewOpt(ft.Description),
						Parameters:  openai.FunctionParameters(ft.InputSchema),
						Strict:      param.NewOpt(false),
					},
					Type: "function",
				},
			})
			continue
		}

		// TODO: handle provider tool calls
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return openAiTools, openAiToolChoice, warnings
	}

	switch *toolChoice {
	case ai.ToolChoiceAuto:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("auto"),
		}
	case ai.ToolChoiceNone:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("none"),
		}
	default:
		// Any other value is treated as the name of a tool to force.
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
				Type: "function",
				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
					Name: string(*toolChoice),
				},
			},
		}
	}
	return openAiTools, openAiToolChoice, warnings
}
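
// A sketch of a function tool as consumed above. The schema and field values
// are illustrative; the ai.FunctionTool fields follow the usage in
// toOpenAiTools (Name, Description, InputSchema).
//
//	weatherTool := ai.FunctionTool{
//		Name:        "get_weather", // hypothetical tool name
//		Description: "Look up current weather for a city.",
//		InputSchema: map[string]any{ // assumed schema type
//			"type": "object",
//			"properties": map[string]any{
//				"city": map[string]any{"type": "string"},
//			},
//		},
//	}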

func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
	var messages []openai.ChatCompletionMessageParamUnion
	var warnings []ai.CallWarning
	for _, msg := range prompt {
		switch msg.Role {
		case ai.MessageRoleSystem:
			var systemPromptParts []string
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeText {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt can only have text content",
					})
					continue
				}
				textPart, ok := ai.AsContentType[ai.TextPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt text part does not have the right type",
					})
					continue
				}
				if strings.TrimSpace(textPart.Text) != "" {
					systemPromptParts = append(systemPromptParts, textPart.Text)
				}
			}
			if len(systemPromptParts) == 0 {
				warnings = append(warnings, ai.CallWarning{
					Type:    ai.CallWarningTypeOther,
					Message: "system prompt has no text parts",
				})
				continue
			}
			messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
		case ai.MessageRoleUser:
			// A simple user message with only text content.
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "user message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.UserMessage(textPart.Text))
				continue
			}
			// Text content plus attachments. Only a few media types are
			// supported for now.
			// TODO: add the supported media types to the language model so we
			// can use that to validate the data here.
			var content []openai.ChatCompletionContentPartUnionParam
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message text part does not have the right type",
						})
						continue
					}
					content = append(content, openai.ChatCompletionContentPartUnionParam{
						OfText: &openai.ChatCompletionContentPartTextParam{
							Text: textPart.Text,
						},
					})
				case ai.ContentTypeFile:
					filePart, ok := ai.AsContentType[ai.FilePart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message file part does not have the right type",
						})
						continue
					}

					switch {
					case strings.HasPrefix(filePart.MediaType, "image/"):
						// Handle image files.
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
						imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}

						// Check for provider-specific options like image detail.
						if providerOptions, ok := filePart.ProviderOptions[ProviderOptionsKey]; ok {
							if detail, ok := providerOptions.(*ProviderFileOptions); ok {
								imageURL.Detail = detail.ImageDetail
							}
						}

						imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})

					case filePart.MediaType == "audio/wav":
						// Handle WAV audio files.
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "wav",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
						// Handle MP3 audio files.
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "mp3",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "application/pdf":
						// Handle PDF files.
						dataStr := string(filePart.Data)

						// Check if the data looks like a file ID (starts with "file-").
						if strings.HasPrefix(dataStr, "file-") {
							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									FileID: param.NewOpt(dataStr),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						} else {
							// Handle as base64 data.
							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
							data := "data:application/pdf;base64," + base64Encoded

							filename := filePart.Filename
							if filename == "" {
								// Generate a default filename based on the content index.
								filename = fmt.Sprintf("part-%d.pdf", len(content))
							}

							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									Filename: param.NewOpt(filename),
									FileData: param.NewOpt(data),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						}

					default:
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
						})
					}
				}
			}
			messages = append(messages, openai.UserMessage(content))
		case ai.MessageRoleAssistant:
			// A simple assistant message with only text content.
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "assistant message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.AssistantMessage(textPart.Text))
				continue
			}
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message text part does not have the right type",
						})
						continue
					}
					assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
						OfString: param.NewOpt(textPart.Text),
					}
				case ai.ContentTypeToolCall:
					toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message tool part does not have the right type",
						})
						continue
					}
					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
						openai.ChatCompletionMessageToolCallUnionParam{
							OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
								ID:   toolCallPart.ToolCallID,
								Type: "function",
								Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
									Name:      toolCallPart.ToolName,
									Arguments: toolCallPart.Input,
								},
							},
						})
				}
			}
			messages = append(messages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})
		case ai.MessageRoleTool:
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeToolResult {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message can only have tool result content",
					})
					continue
				}

				toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message result part does not have the right type",
					})
					continue
				}

				switch toolResultPart.Output.GetType() {
				case ai.ToolResultContentTypeText:
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
				case ai.ToolResultContentTypeError:
					// TODO: check if better handling is needed
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
				}
			}
		}
	}
	return messages, warnings
}
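
// A prompt sketch covering the roles handled above. The concrete slice type
// of Message.Content is an assumption; the part types mirror those consumed
// by toPrompt.
//
//	prompt := ai.Prompt{
//		{Role: ai.MessageRoleSystem, Content: []ai.MessagePart{ // assumed slice type
//			ai.TextPart{Text: "You are terse."},
//		}},
//		{Role: ai.MessageRoleUser, Content: []ai.MessagePart{
//			ai.TextPart{Text: "Summarize this file."},
//		}},
//	}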

// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
	var annotations []openai.ChatCompletionMessageAnnotation

	// Parse the raw JSON to extract annotations.
	var deltaData map[string]any
	if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
		return annotations
	}

	// Check if annotations exist in the delta.
	annotationsData, ok := deltaData["annotations"].([]any)
	if !ok {
		return annotations
	}
	for _, annotationData := range annotationsData {
		annotationMap, ok := annotationData.(map[string]any)
		if !ok {
			continue
		}
		if annotationType, ok := annotationMap["type"].(string); !ok || annotationType != "url_citation" {
			continue
		}
		urlCitationData, ok := annotationMap["url_citation"].(map[string]any)
		if !ok {
			continue
		}
		// Use checked assertions so malformed payloads cannot panic.
		url, _ := urlCitationData["url"].(string)
		title, _ := urlCitationData["title"].(string)
		annotations = append(annotations, openai.ChatCompletionMessageAnnotation{
			Type: "url_citation",
			URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
				URL:   url,
				Title: title,
			},
		})
	}

	return annotations
}