package openai

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"maps"
	"strings"

	"github.com/charmbracelet/ai"
	"github.com/charmbracelet/ai/internal/jsonext"
	"github.com/google/uuid"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/option"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/openai/openai-go/v2/shared"
)

type provider struct {
	options options
}

type options struct {
	baseURL      string
	apiKey       string
	organization string
	project      string
	name         string
	headers      map[string]string
	client       option.HTTPClient
}

type Option = func(*options)

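// New constructs an OpenAI-backed ai.Provider. A minimal usage sketch (the
// model ID and environment variable are illustrative):
//
//	provider := New(
//		WithAPIKey(os.Getenv("OPENAI_API_KEY")),
//	)
//	model, err := provider.LanguageModel("gpt-4o")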
func New(opts ...Option) ai.Provider {
	options := options{
		headers: map[string]string{},
	}
	for _, o := range opts {
		o(&options)
	}

	if options.baseURL == "" {
		options.baseURL = "https://api.openai.com/v1"
	}

	if options.name == "" {
		options.name = "openai"
	}

	if options.organization != "" {
		options.headers["OpenAI-Organization"] = options.organization
	}

	if options.project != "" {
		options.headers["OpenAI-Project"] = options.project
	}

	return &provider{
		options: options,
	}
}

func WithBaseURL(baseURL string) Option {
	return func(o *options) {
		o.baseURL = baseURL
	}
}

func WithAPIKey(apiKey string) Option {
	return func(o *options) {
		o.apiKey = apiKey
	}
}

func WithOrganization(organization string) Option {
	return func(o *options) {
		o.organization = organization
	}
}

func WithProject(project string) Option {
	return func(o *options) {
		o.project = project
	}
}

func WithName(name string) Option {
	return func(o *options) {
		o.name = name
	}
}

func WithHeaders(headers map[string]string) Option {
	return func(o *options) {
		maps.Copy(o.headers, headers)
	}
}

func WithHTTPClient(client option.HTTPClient) Option {
	return func(o *options) {
		o.client = client
	}
}

// LanguageModel implements ai.Provider.
func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
	openaiClientOptions := []option.RequestOption{}
	if o.options.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
	}
	if o.options.baseURL != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
	}

	for key, value := range o.options.headers {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	if o.options.client != nil {
		openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
	}

	return languageModel{
		modelID:  modelID,
		provider: fmt.Sprintf("%s.chat", o.options.name),
		options:  o.options,
		client:   openai.NewClient(openaiClientOptions...),
	}, nil
}

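// languageModel implements ai.LanguageModel on top of the OpenAI
// chat-completions API.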
type languageModel struct {
	provider string
	modelID  string
	client   openai.Client
	options  options
}

// Model implements ai.LanguageModel.
func (o languageModel) Model() string {
	return o.modelID
}

// Provider implements ai.LanguageModel.
func (o languageModel) Provider() string {
	return o.provider
}

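// prepareParams translates an ai.Call into chat-completions request
// parameters, collecting warnings for settings the target model does not
// support. Provider-specific options are read from
// call.ProviderOptions["openai"].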
func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
	params := &openai.ChatCompletionNewParams{}
	messages, warnings := toPrompt(call.Prompt)
	providerOptions := &providerOptions{}
	if v, ok := call.ProviderOptions["openai"]; ok {
		err := ai.ParseOptions(v, providerOptions)
		if err != nil {
			return nil, nil, err
		}
	}
	if call.TopK != nil {
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedSetting,
			Setting: "top_k",
			Details: "top_k is not supported by the OpenAI chat API",
		})
	}
	params.Messages = messages
	params.Model = o.modelID
	if providerOptions.LogitBias != nil {
		params.LogitBias = providerOptions.LogitBias
	}
	// When both are set, only top_logprobs is forwarded.
	if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
		providerOptions.LogProbs = nil
	}
	if providerOptions.LogProbs != nil {
		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
	}
	if providerOptions.TopLogProbs != nil {
		params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
	}
	if providerOptions.User != nil {
		params.User = param.NewOpt(*providerOptions.User)
	}
	if providerOptions.ParallelToolCalls != nil {
		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
	}

	if call.MaxOutputTokens != nil {
		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
	}
	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}
	if call.FrequencyPenalty != nil {
		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
	}
	if call.PresencePenalty != nil {
		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
	}

	if providerOptions.MaxCompletionTokens != nil {
		params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
	}

	if providerOptions.TextVerbosity != nil {
		params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
	}
	if providerOptions.Prediction != nil {
		// Convert map[string]any to ChatCompletionPredictionContentParam
		if content, ok := providerOptions.Prediction["content"]; ok {
			if contentStr, ok := content.(string); ok {
				params.Prediction = openai.ChatCompletionPredictionContentParam{
					Content: openai.ChatCompletionPredictionContentContentUnionParam{
						OfString: param.NewOpt(contentStr),
					},
				}
			}
		}
	}
	if providerOptions.Store != nil {
		params.Store = param.NewOpt(*providerOptions.Store)
	}
	if providerOptions.Metadata != nil {
		// Convert map[string]any to map[string]string
		metadata := make(map[string]string)
		for k, v := range providerOptions.Metadata {
			if str, ok := v.(string); ok {
				metadata[k] = str
			}
		}
		params.Metadata = metadata
	}
	if providerOptions.PromptCacheKey != nil {
		params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
	}
	if providerOptions.SafetyIdentifier != nil {
		params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
	}
	if providerOptions.ServiceTier != nil {
		params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
	}

	if providerOptions.ReasoningEffort != nil {
		switch *providerOptions.ReasoningEffort {
		case reasoningEffortMinimal:
			params.ReasoningEffort = shared.ReasoningEffortMinimal
		case reasoningEffortLow:
			params.ReasoningEffort = shared.ReasoningEffortLow
		case reasoningEffortMedium:
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case reasoningEffortHigh:
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			return nil, nil, fmt.Errorf("reasoning effort `%s` not supported", *providerOptions.ReasoningEffort)
		}
	}

	if isReasoningModel(o.modelID) {
		// remove unsupported settings for reasoning models
		// see https://platform.openai.com/docs/guides/reasoning#limitations
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for reasoning models",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "top_p",
				Details: "top_p is not supported for reasoning models",
			})
		}
		if call.FrequencyPenalty != nil {
			params.FrequencyPenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "frequency_penalty",
				Details: "frequency_penalty is not supported for reasoning models",
			})
		}
		if call.PresencePenalty != nil {
			params.PresencePenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "presence_penalty",
				Details: "presence_penalty is not supported for reasoning models",
			})
		}
		if providerOptions.LogitBias != nil {
			params.LogitBias = nil
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "logit_bias",
				Details: "logit_bias is not supported for reasoning models",
			})
		}
		if providerOptions.LogProbs != nil {
			params.Logprobs = param.Opt[bool]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "logprobs",
				Details: "logprobs is not supported for reasoning models",
			})
		}
		if providerOptions.TopLogProbs != nil {
			params.TopLogprobs = param.Opt[int64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "top_logprobs",
				Details: "top_logprobs is not supported for reasoning models",
			})
		}

		// reasoning models use max_completion_tokens instead of max_tokens
		if call.MaxOutputTokens != nil {
			if providerOptions.MaxCompletionTokens == nil {
				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
			}
			params.MaxTokens = param.Opt[int64]{}
		}
	}

	// Handle search preview models
	if isSearchPreviewModel(o.modelID) {
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for the search preview models and has been removed",
			})
		}
	}

	// Handle service tier validation
	if providerOptions.ServiceTier != nil {
		serviceTier := *providerOptions.ServiceTier
		if serviceTier == "flex" && !supportsFlexProcessing(o.modelID) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "service_tier",
				Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
			})
		} else if serviceTier == "priority" && !supportsPriorityProcessing(o.modelID) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "service_tier",
				Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access; gpt-5-nano is not supported",
			})
		}
	}

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}
	return params, warnings, nil
}

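// handleError converts OpenAI SDK errors into ai.APICallError values,
// preserving the status code, lower-cased response headers, and
// request/response dumps; other errors are returned unchanged.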
func (o languageModel) handleError(err error) error {
	var apiErr *openai.Error
	if errors.As(err, &apiErr) {
		requestDump := apiErr.DumpRequest(true)
		responseDump := apiErr.DumpResponse(true)
		headers := map[string]string{}
		for k, h := range apiErr.Response.Header {
			v := h[len(h)-1]
			headers[strings.ToLower(k)] = v
		}
		return ai.NewAPICallError(
			apiErr.Message,
			apiErr.Request.URL.String(),
			string(requestDump),
			apiErr.StatusCode,
			headers,
			string(responseDump),
			apiErr,
			false,
		)
	}
	return err
}

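// Generate makes a single blocking chat-completions request and maps the
// response into ai.Content values (text, tool calls, and URL-citation
// sources).
//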
// Generate implements ai.LanguageModel.
func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, o.handleError(err)
	}

	if len(response.Choices) == 0 {
		return nil, errors.New("no response generated")
	}
	choice := response.Choices[0]
	content := make([]ai.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
	text := choice.Message.Content
	if text != "" {
		content = append(content, ai.TextContent{
			Text: text,
		})
	}

	for _, tc := range choice.Message.ToolCalls {
		toolCallID := tc.ID
		if toolCallID == "" {
			toolCallID = uuid.NewString()
		}
		content = append(content, ai.ToolCallContent{
			ProviderExecuted: false, // TODO: update when handling other tools
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Handle annotations/citations
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, ai.SourceContent{
				SourceType: ai.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	completionTokenDetails := response.Usage.CompletionTokensDetails
	promptTokenDetails := response.Usage.PromptTokensDetails

	// Build provider metadata
	providerMetadata := ai.ProviderMetadata{
		"openai": make(map[string]any),
	}

	// Add logprobs if available
	if len(choice.Logprobs.Content) > 0 {
		providerMetadata["openai"]["logprobs"] = choice.Logprobs.Content
	}

	// Add prediction tokens if available
	if completionTokenDetails.AcceptedPredictionTokens > 0 {
		providerMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
	}
	if completionTokenDetails.RejectedPredictionTokens > 0 {
		providerMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
	}

	return &ai.Response{
		Content: content,
		Usage: ai.Usage{
			InputTokens:     response.Usage.PromptTokens,
			OutputTokens:    response.Usage.CompletionTokens,
			TotalTokens:     response.Usage.TotalTokens,
			ReasoningTokens: completionTokenDetails.ReasoningTokens,
			CacheReadTokens: promptTokenDetails.CachedTokens,
		},
		FinishReason:     mapOpenAiFinishReason(choice.FinishReason),
		ProviderMetadata: providerMetadata,
		Warnings:         warnings,
	}, nil
}

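// toolCall accumulates streamed tool-call deltas; once its arguments form
// valid JSON the call is emitted and marked finished.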
type toolCall struct {
	id          string
	name        string
	arguments   string
	hasFinished bool
}

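// Stream starts a streaming chat-completions request. The returned
// ai.StreamResponse is built as a range-over-func iterator, so a caller can
// consume it roughly like this (a sketch; error handling elided):
//
//	parts, err := model.Stream(ctx, call)
//	if err != nil {
//		// handle error
//	}
//	for part := range parts {
//		if part.Type == ai.StreamPartTypeTextDelta {
//			fmt.Print(part.Delta)
//		}
//	}
//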
// Stream implements ai.LanguageModel.
func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}

	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
	isActiveText := false
	toolCalls := make(map[int64]toolCall)

	// Build provider metadata for streaming
	streamProviderMetadata := ai.ProviderMetadata{
		"openai": make(map[string]any),
	}

	acc := openai.ChatCompletionAccumulator{}
	var usage ai.Usage
	return func(yield func(ai.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(ai.StreamPart{
				Type:     ai.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}
		for stream.Next() {
			chunk := stream.Current()
			acc.AddChunk(chunk)
			if chunk.Usage.TotalTokens > 0 {
				// we do this here because the accumulator does not add prompt details
				completionTokenDetails := chunk.Usage.CompletionTokensDetails
				promptTokenDetails := chunk.Usage.PromptTokensDetails
				usage = ai.Usage{
					InputTokens:     chunk.Usage.PromptTokens,
					OutputTokens:    chunk.Usage.CompletionTokens,
					TotalTokens:     chunk.Usage.TotalTokens,
					ReasoningTokens: completionTokenDetails.ReasoningTokens,
					CacheReadTokens: promptTokenDetails.CachedTokens,
				}

				// Add prediction tokens if available
				if completionTokenDetails.AcceptedPredictionTokens > 0 {
					streamProviderMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
				}
				if completionTokenDetails.RejectedPredictionTokens > 0 {
					streamProviderMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
				}
			}
			if len(chunk.Choices) == 0 {
				continue
			}
			for _, choice := range chunk.Choices {
				switch {
				case choice.Delta.Content != "":
					if !isActiveText {
						isActiveText = true
						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeTextStart,
							ID:   "0",
						}) {
							return
						}
					}
					if !yield(ai.StreamPart{
						Type:  ai.StreamPartTypeTextDelta,
						ID:    "0",
						Delta: choice.Delta.Content,
					}) {
						return
					}
				case len(choice.Delta.ToolCalls) > 0:
					if isActiveText {
						isActiveText = false
						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeTextEnd,
							ID:   "0",
						}) {
							return
						}
					}

					for _, toolCallDelta := range choice.Delta.ToolCalls {
						if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
							if existingToolCall.hasFinished {
								continue
							}
							if toolCallDelta.Function.Arguments != "" {
								existingToolCall.arguments += toolCallDelta.Function.Arguments
							}
							if !yield(ai.StreamPart{
								Type:  ai.StreamPartTypeToolInputDelta,
								ID:    existingToolCall.id,
								Delta: toolCallDelta.Function.Arguments,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = existingToolCall
							if jsonext.IsValidJSON(existingToolCall.arguments) {
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeToolInputEnd,
									ID:   existingToolCall.id,
								}) {
									return
								}

								if !yield(ai.StreamPart{
									Type:          ai.StreamPartTypeToolCall,
									ID:            existingToolCall.id,
									ToolCallName:  existingToolCall.name,
									ToolCallInput: existingToolCall.arguments,
								}) {
									return
								}
								existingToolCall.hasFinished = true
								toolCalls[toolCallDelta.Index] = existingToolCall
							}
						} else {
							// First delta for this tool call: validate it
							// before tracking it.
							var err error
							if toolCallDelta.Type != "function" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
							}
							if toolCallDelta.ID == "" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
							}
							if toolCallDelta.Function.Name == "" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
							}
							if err != nil {
								// Yield the validation error; stream.Err() is
								// nil at this point.
								yield(ai.StreamPart{
									Type:  ai.StreamPartTypeError,
									Error: o.handleError(err),
								})
								return
							}

							if !yield(ai.StreamPart{
								Type:         ai.StreamPartTypeToolInputStart,
								ID:           toolCallDelta.ID,
								ToolCallName: toolCallDelta.Function.Name,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = toolCall{
								id:        toolCallDelta.ID,
								name:      toolCallDelta.Function.Name,
								arguments: toolCallDelta.Function.Arguments,
							}

							exTc := toolCalls[toolCallDelta.Index]
							if exTc.arguments != "" {
								if !yield(ai.StreamPart{
									Type:  ai.StreamPartTypeToolInputDelta,
									ID:    exTc.id,
									Delta: exTc.arguments,
								}) {
									return
								}
								if jsonext.IsValidJSON(exTc.arguments) {
									if !yield(ai.StreamPart{
										Type: ai.StreamPartTypeToolInputEnd,
										ID:   exTc.id,
									}) {
										return
									}

									if !yield(ai.StreamPart{
										Type:          ai.StreamPartTypeToolCall,
										ID:            exTc.id,
										ToolCallName:  exTc.name,
										ToolCallInput: exTc.arguments,
									}) {
										return
									}
									exTc.hasFinished = true
									toolCalls[toolCallDelta.Index] = exTc
								}
							}
						}
					}
				}
			}

			// Check for annotations in the delta's raw JSON
			for _, choice := range chunk.Choices {
				for _, annotation := range parseAnnotationsFromDelta(choice.Delta) {
					if annotation.Type == "url_citation" {
						if !yield(ai.StreamPart{
							Type:       ai.StreamPartTypeSource,
							ID:         uuid.NewString(),
							SourceType: ai.SourceTypeURL,
							URL:        annotation.URLCitation.URL,
							Title:      annotation.URLCitation.Title,
						}) {
							return
						}
					}
				}
			}
		}
		err := stream.Err()
		if err != nil && !errors.Is(err, io.EOF) {
			yield(ai.StreamPart{
				Type:  ai.StreamPartTypeError,
				Error: o.handleError(err),
			})
			return
		}

		// finished
		if isActiveText {
			isActiveText = false
			if !yield(ai.StreamPart{
				Type: ai.StreamPartTypeTextEnd,
				ID:   "0",
			}) {
				return
			}
		}

		// Add logprobs if available
		if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
			streamProviderMetadata["openai"]["logprobs"] = acc.Choices[0].Logprobs.Content
		}

		// Handle annotations/citations from the accumulated response
		if len(acc.Choices) > 0 {
			for _, annotation := range acc.Choices[0].Message.Annotations {
				if annotation.Type == "url_citation" {
					if !yield(ai.StreamPart{
						Type:       ai.StreamPartTypeSource,
						ID:         acc.ID,
						SourceType: ai.SourceTypeURL,
						URL:        annotation.URLCitation.URL,
						Title:      annotation.URLCitation.Title,
					}) {
						return
					}
				}
			}
		}

		// Guard against an empty accumulator so we do not panic when the
		// stream produced no choices.
		finishReason := ai.FinishReasonUnknown
		if len(acc.Choices) > 0 {
			finishReason = mapOpenAiFinishReason(acc.Choices[0].FinishReason)
		}
		yield(ai.StreamPart{
			Type:             ai.StreamPartTypeFinish,
			Usage:            usage,
			FinishReason:     finishReason,
			ProviderMetadata: streamProviderMetadata,
		})
	}, nil
}

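// mapOpenAiFinishReason maps the finish_reason string returned by the API to
// the corresponding ai.FinishReason, defaulting to unknown.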
func mapOpenAiFinishReason(finishReason string) ai.FinishReason {
	switch finishReason {
	case "stop":
		return ai.FinishReasonStop
	case "length":
		return ai.FinishReasonLength
	case "content_filter":
		return ai.FinishReasonContentFilter
	case "function_call", "tool_calls":
		return ai.FinishReasonToolCalls
	default:
		return ai.FinishReasonUnknown
	}
}

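// isReasoningModel reports whether the model is a reasoning model (o-series
// or gpt-5, excluding the non-reasoning gpt-5-chat variants), which accepts
// only a restricted set of sampling parameters.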
func isReasoningModel(modelID string) bool {
	// gpt-5-chat models are conversational, non-reasoning variants.
	return (strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5")) &&
		!strings.HasPrefix(modelID, "gpt-5-chat")
}

func isSearchPreviewModel(modelID string) bool {
	return strings.Contains(modelID, "search-preview")
}

func supportsFlexProcessing(modelID string) bool {
	return strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini") || strings.HasPrefix(modelID, "gpt-5")
}

func supportsPriorityProcessing(modelID string) bool {
	// gpt-5-nano does not support priority processing; see the warning
	// emitted in prepareParams.
	return strings.HasPrefix(modelID, "gpt-4") ||
		(strings.HasPrefix(modelID, "gpt-5") && !strings.HasPrefix(modelID, "gpt-5-nano")) ||
		strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini")
}

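// toOpenAiTools converts ai.Tool definitions and an optional ai.ToolChoice
// into their chat-completions equivalents; non-function tools are skipped
// with a warning.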
func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				continue
			}
			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
				OfFunction: &openai.ChatCompletionFunctionToolParam{
					Function: shared.FunctionDefinitionParam{
						Name:        ft.Name,
						Description: param.NewOpt(ft.Description),
						Parameters:  openai.FunctionParameters(ft.InputSchema),
						Strict:      param.NewOpt(false),
					},
					Type: "function",
				},
			})
			continue
		}

		// TODO: handle provider tool calls
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return openAiTools, openAiToolChoice, warnings
	}

	switch *toolChoice {
	case ai.ToolChoiceAuto:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("auto"),
		}
	case ai.ToolChoiceNone:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("none"),
		}
	default:
		// Any other value is treated as the name of a specific tool to force.
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
				Type: "function",
				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
					Name: string(*toolChoice),
				},
			},
		}
	}
	return openAiTools, openAiToolChoice, warnings
}

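// toPrompt converts an ai.Prompt into chat-completions messages, emitting
// warnings for content it cannot represent (non-text system parts,
// unsupported file media types, and so on).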
func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
	var messages []openai.ChatCompletionMessageParamUnion
	var warnings []ai.CallWarning
	for _, msg := range prompt {
		switch msg.Role {
		case ai.MessageRoleSystem:
			var systemPromptParts []string
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeText {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt can only have text content",
					})
					continue
				}
				textPart, ok := ai.AsContentType[ai.TextPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt text part does not have the right type",
					})
					continue
				}
				text := textPart.Text
				if strings.TrimSpace(text) != "" {
					systemPromptParts = append(systemPromptParts, textPart.Text)
				}
			}
			if len(systemPromptParts) == 0 {
				warnings = append(warnings, ai.CallWarning{
					Type:    ai.CallWarningTypeOther,
					Message: "system prompt has no text parts",
				})
				continue
			}
			messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
		case ai.MessageRoleUser:
			// a simple user message with only text content
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "user message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.UserMessage(textPart.Text))
				continue
			}
			// text content plus attachments; images, wav/mp3 audio, and PDFs
			// are supported for now
			// TODO: add the supported media types to the language model so we
			//  can use that to validate the data here.
			var content []openai.ChatCompletionContentPartUnionParam
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message text part does not have the right type",
						})
						continue
					}
					content = append(content, openai.ChatCompletionContentPartUnionParam{
						OfText: &openai.ChatCompletionContentPartTextParam{
							Text: textPart.Text,
						},
					})
				case ai.ContentTypeFile:
					filePart, ok := ai.AsContentType[ai.FilePart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message file part does not have the right type",
						})
						continue
					}

					switch {
					case strings.HasPrefix(filePart.MediaType, "image/"):
						// Handle image files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
						imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}

						// Check for provider-specific options like image detail
						if providerOptions, ok := filePart.ProviderOptions["openai"]; ok {
							if detail, ok := providerOptions["imageDetail"].(string); ok {
								imageURL.Detail = detail
							}
						}

						imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})

					case filePart.MediaType == "audio/wav":
						// Handle WAV audio files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "wav",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
						// Handle MP3 audio files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "mp3",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "application/pdf":
						// Handle PDF files
						dataStr := string(filePart.Data)

						// Check if data looks like a file ID (starts with "file-")
						if strings.HasPrefix(dataStr, "file-") {
							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									FileID: param.NewOpt(dataStr),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						} else {
							// Handle as base64 data
							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
							data := "data:application/pdf;base64," + base64Encoded

							filename := filePart.Filename
							if filename == "" {
								// Generate a default filename based on the content index
								filename = fmt.Sprintf("part-%d.pdf", len(content))
							}

							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									Filename: param.NewOpt(filename),
									FileData: param.NewOpt(data),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						}

					default:
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
						})
					}
				}
			}
			messages = append(messages, openai.UserMessage(content))
		case ai.MessageRoleAssistant:
			// a simple assistant message with only text content
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "assistant message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.AssistantMessage(textPart.Text))
				continue
			}
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message text part does not have the right type",
						})
						continue
					}
					assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
						OfString: param.NewOpt(textPart.Text),
					}
				case ai.ContentTypeToolCall:
					toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message tool part does not have the right type",
						})
						continue
					}
					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
						openai.ChatCompletionMessageToolCallUnionParam{
							OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
								ID:   toolCallPart.ToolCallID,
								Type: "function",
								Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
									Name:      toolCallPart.ToolName,
									Arguments: toolCallPart.Input,
								},
							},
						})
				}
			}
			messages = append(messages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})
		case ai.MessageRoleTool:
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeToolResult {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message can only have tool result content",
					})
					continue
				}

				toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message result part does not have the right type",
					})
					continue
				}

				switch toolResultPart.Output.GetType() {
				case ai.ToolResultContentTypeText:
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
				case ai.ToolResultContentTypeError:
					// TODO: check if better handling is needed
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
				}
			}
		}
	}
	return messages, warnings
}

// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
	var annotations []openai.ChatCompletionMessageAnnotation

	// Parse the raw JSON to extract annotations
	var deltaData map[string]any
	if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
		return annotations
	}

	// Check if annotations exist in the delta
	annotationsData, ok := deltaData["annotations"].([]any)
	if !ok {
		return annotations
	}
	for _, annotationData := range annotationsData {
		annotationMap, ok := annotationData.(map[string]any)
		if !ok {
			continue
		}
		if annotationType, ok := annotationMap["type"].(string); !ok || annotationType != "url_citation" {
			continue
		}
		urlCitationData, ok := annotationMap["url_citation"].(map[string]any)
		if !ok {
			continue
		}
		// Use checked type assertions so malformed payloads cannot panic.
		url, _ := urlCitationData["url"].(string)
		title, _ := urlCitationData["title"].(string)
		annotations = append(annotations, openai.ChatCompletionMessageAnnotation{
			Type: "url_citation",
			URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
				URL:   url,
				Title: title,
			},
		})
	}

	return annotations
}