openai.go

package openai

import (
	"cmp"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"maps"
	"strings"

	"github.com/charmbracelet/ai"
	"github.com/google/uuid"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/option"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/openai/openai-go/v2/shared"
)
type provider struct {
	options options
}
type options struct {
	baseURL      string
	apiKey       string
	organization string
	project      string
	name         string
	headers      map[string]string
	client       option.HTTPClient
}
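// Option configures the provider returned by New.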
type Option = func(*options)
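// New constructs an ai.Provider backed by the OpenAI chat completions API.
// A minimal usage sketch (the key source and model ID are illustrative):
//
//	p := openai.New(openai.WithAPIKey(os.Getenv("OPENAI_API_KEY")))
//	model, err := p.LanguageModel("gpt-4o")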
func New(opts ...Option) ai.Provider {
	options := options{
		headers: map[string]string{},
	}
	for _, o := range opts {
		o(&options)
	}

	options.baseURL = cmp.Or(options.baseURL, "https://api.openai.com/v1")
	options.name = cmp.Or(options.name, "openai")

	if options.organization != "" {
		options.headers["OpenAI-Organization"] = options.organization
	}
	if options.project != "" {
		options.headers["OpenAI-Project"] = options.project
	}

	return &provider{options: options}
}
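// WithBaseURL overrides the default base URL of https://api.openai.com/v1.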
func WithBaseURL(baseURL string) Option {
	return func(o *options) {
		o.baseURL = baseURL
	}
}
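// WithAPIKey sets the API key used to authenticate requests.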
func WithAPIKey(apiKey string) Option {
	return func(o *options) {
		o.apiKey = apiKey
	}
}
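// WithOrganization sets the OpenAI-Organization header on every request.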
func WithOrganization(organization string) Option {
	return func(o *options) {
		o.organization = organization
	}
}
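// WithProject sets the OpenAI-Project header on every request.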
func WithProject(project string) Option {
	return func(o *options) {
		o.project = project
	}
}
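// WithName overrides the provider name used in model metadata (defaults to "openai").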
func WithName(name string) Option {
	return func(o *options) {
		o.name = name
	}
}
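// WithHeaders merges additional headers into every request.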
func WithHeaders(headers map[string]string) Option {
	return func(o *options) {
		maps.Copy(o.headers, headers)
	}
}
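// WithHTTPClient sets a custom HTTP client for API requests.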
func WithHTTPClient(client option.HTTPClient) Option {
	return func(o *options) {
		o.client = client
	}
}
// LanguageModel implements ai.Provider.
func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
	openaiClientOptions := []option.RequestOption{}
	if o.options.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
	}
	if o.options.baseURL != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
	}

	for key, value := range o.options.headers {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	if o.options.client != nil {
		openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
	}

	return languageModel{
		modelID:  modelID,
		provider: fmt.Sprintf("%s.chat", o.options.name),
		options:  o.options,
		client:   openai.NewClient(openaiClientOptions...),
	}, nil
}
type languageModel struct {
	provider string
	modelID  string
	client   openai.Client
	options  options
}
// Model implements ai.LanguageModel.
func (o languageModel) Model() string {
	return o.modelID
}
// Provider implements ai.LanguageModel.
func (o languageModel) Provider() string {
	return o.provider
}
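// prepareParams translates an ai.Call into OpenAI chat completion parameters,
// collecting warnings for settings the requested model does not support.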
func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
	params := &openai.ChatCompletionNewParams{}
	messages, warnings := toPrompt(call.Prompt)
	providerOptions := &providerOptions{}
	if v, ok := call.ProviderOptions["openai"]; ok {
		err := ai.ParseOptions(v, providerOptions)
		if err != nil {
			return nil, nil, err
		}
	}
	if call.TopK != nil {
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedSetting,
			Setting: "top_k",
		})
	}
	params.Messages = messages
	params.Model = o.modelID
	if providerOptions.LogitBias != nil {
		params.LogitBias = providerOptions.LogitBias
	}
	// top_logprobs implies logprobs, so send only top_logprobs when both are set.
	if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
		providerOptions.LogProbs = nil
	}
	if providerOptions.LogProbs != nil {
		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
	}
	if providerOptions.TopLogProbs != nil {
		params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
	}
	if providerOptions.User != nil {
		params.User = param.NewOpt(*providerOptions.User)
	}
	if providerOptions.ParallelToolCalls != nil {
		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
	}

	if call.MaxOutputTokens != nil {
		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
	}
	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}
	if call.FrequencyPenalty != nil {
		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
	}
	if call.PresencePenalty != nil {
		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
	}

	if providerOptions.MaxCompletionTokens != nil {
		params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
	}

	if providerOptions.TextVerbosity != nil {
		params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
	}
	if providerOptions.Prediction != nil {
		// Convert map[string]any to ChatCompletionPredictionContentParam.
		if content, ok := providerOptions.Prediction["content"]; ok {
			if contentStr, ok := content.(string); ok {
				params.Prediction = openai.ChatCompletionPredictionContentParam{
					Content: openai.ChatCompletionPredictionContentContentUnionParam{
						OfString: param.NewOpt(contentStr),
					},
				}
			}
		}
	}
	if providerOptions.Store != nil {
		params.Store = param.NewOpt(*providerOptions.Store)
	}
	if providerOptions.Metadata != nil {
		// Convert map[string]any to map[string]string.
		metadata := make(map[string]string)
		for k, v := range providerOptions.Metadata {
			if str, ok := v.(string); ok {
				metadata[k] = str
			}
		}
		params.Metadata = metadata
	}
	if providerOptions.PromptCacheKey != nil {
		params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
	}
	if providerOptions.SafetyIdentifier != nil {
		params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
	}
	if providerOptions.ServiceTier != nil {
		params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
	}
	if providerOptions.ReasoningEffort != nil {
		switch *providerOptions.ReasoningEffort {
		case reasoningEffortMinimal:
			params.ReasoningEffort = shared.ReasoningEffortMinimal
		case reasoningEffortLow:
			params.ReasoningEffort = shared.ReasoningEffortLow
		case reasoningEffortMedium:
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case reasoningEffortHigh:
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			return nil, nil, fmt.Errorf("reasoning effort %q is not supported", *providerOptions.ReasoningEffort)
		}
	}
	if isReasoningModel(o.modelID) {
		// Remove settings that reasoning models do not support.
		// See https://platform.openai.com/docs/guides/reasoning#limitations
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for reasoning models",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "top_p",
				Details: "top_p is not supported for reasoning models",
			})
		}
		if call.FrequencyPenalty != nil {
			params.FrequencyPenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "frequency_penalty",
				Details: "frequency_penalty is not supported for reasoning models",
			})
		}
		if call.PresencePenalty != nil {
			params.PresencePenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "presence_penalty",
				Details: "presence_penalty is not supported for reasoning models",
			})
		}
		if providerOptions.LogitBias != nil {
			params.LogitBias = nil
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "logit_bias",
				Details: "logit_bias is not supported for reasoning models",
			})
		}
		if providerOptions.LogProbs != nil {
			params.Logprobs = param.Opt[bool]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "logprobs",
				Details: "logprobs is not supported for reasoning models",
			})
		}
		if providerOptions.TopLogProbs != nil {
			params.TopLogprobs = param.Opt[int64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "top_logprobs",
				Details: "top_logprobs is not supported for reasoning models",
			})
		}

		// Reasoning models use max_completion_tokens instead of max_tokens.
		if call.MaxOutputTokens != nil {
			if providerOptions.MaxCompletionTokens == nil {
				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
			}
			params.MaxTokens = param.Opt[int64]{}
		}
	}
	// Handle search preview models.
	if isSearchPreviewModel(o.modelID) {
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for the search preview models and has been removed.",
			})
		}
	}

	// Validate the requested service tier against model support.
	if providerOptions.ServiceTier != nil {
		serviceTier := *providerOptions.ServiceTier
		if serviceTier == "flex" && !supportsFlexProcessing(o.modelID) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "service_tier",
				Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
			})
		} else if serviceTier == "priority" && !supportsPriorityProcessing(o.modelID) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "service_tier",
				Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access; gpt-5-nano is not supported",
			})
		}
	}

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}
	return params, warnings, nil
}
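// handleError converts OpenAI SDK errors into ai.APICallError values that
// preserve the status code, headers, and request/response dumps.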
func (o languageModel) handleError(err error) error {
	var apiErr *openai.Error
	if errors.As(err, &apiErr) {
		requestDump := apiErr.DumpRequest(true)
		responseDump := apiErr.DumpResponse(true)
		headers := map[string]string{}
		for k, h := range apiErr.Response.Header {
			if len(h) == 0 {
				continue
			}
			headers[strings.ToLower(k)] = h[len(h)-1]
		}
		return ai.NewAPICallError(
			apiErr.Message,
			apiErr.Request.URL.String(),
			string(requestDump),
			apiErr.StatusCode,
			headers,
			string(responseDump),
			apiErr,
			false,
		)
	}
	return err
}
// Generate implements ai.LanguageModel.
func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, o.handleError(err)
	}

	if len(response.Choices) == 0 {
		return nil, errors.New("no response generated")
	}
	choice := response.Choices[0]
	content := make([]ai.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
	text := choice.Message.Content
	if text != "" {
		content = append(content, ai.TextContent{
			Text: text,
		})
	}

	for _, tc := range choice.Message.ToolCalls {
		toolCallID := tc.ID
		if toolCallID == "" {
			toolCallID = uuid.NewString()
		}
		content = append(content, ai.ToolCallContent{
			ProviderExecuted: false, // TODO: update when handling other tools
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Handle annotations/citations.
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, ai.SourceContent{
				SourceType: ai.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	completionTokenDetails := response.Usage.CompletionTokensDetails
	promptTokenDetails := response.Usage.PromptTokensDetails

	// Build provider metadata.
	providerMetadata := ai.ProviderMetadata{
		"openai": make(map[string]any),
	}

	// Add logprobs if available.
	if len(choice.Logprobs.Content) > 0 {
		providerMetadata["openai"]["logprobs"] = choice.Logprobs.Content
	}

	// Add prediction token counts if available.
	if completionTokenDetails.AcceptedPredictionTokens > 0 {
		providerMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
	}
	if completionTokenDetails.RejectedPredictionTokens > 0 {
		providerMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
	}

	return &ai.Response{
		Content: content,
		Usage: ai.Usage{
			InputTokens:     response.Usage.PromptTokens,
			OutputTokens:    response.Usage.CompletionTokens,
			TotalTokens:     response.Usage.TotalTokens,
			ReasoningTokens: completionTokenDetails.ReasoningTokens,
			CacheReadTokens: promptTokenDetails.CachedTokens,
		},
		FinishReason:     mapOpenAiFinishReason(choice.FinishReason),
		ProviderMetadata: providerMetadata,
		Warnings:         warnings,
	}, nil
}
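// toolCall tracks the accumulated state of a single tool call across
// streamed deltas.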
type toolCall struct {
	id          string
	name        string
	arguments   string
	hasFinished bool
}
// Stream implements ai.LanguageModel.
func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}

	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
	isActiveText := false
	toolCalls := make(map[int64]toolCall)

	// Provider metadata accumulated while streaming.
	streamProviderMetadata := ai.ProviderMetadata{
		"openai": make(map[string]any),
	}

	acc := openai.ChatCompletionAccumulator{}
	var usage ai.Usage
	return func(yield func(ai.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(ai.StreamPart{
				Type:     ai.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}
		for stream.Next() {
			chunk := stream.Current()
			acc.AddChunk(chunk)
			if chunk.Usage.TotalTokens > 0 {
				// Handled here because the accumulator does not carry over
				// the prompt and completion token details.
				completionTokenDetails := chunk.Usage.CompletionTokensDetails
				promptTokenDetails := chunk.Usage.PromptTokensDetails
				usage = ai.Usage{
					InputTokens:     chunk.Usage.PromptTokens,
					OutputTokens:    chunk.Usage.CompletionTokens,
					TotalTokens:     chunk.Usage.TotalTokens,
					ReasoningTokens: completionTokenDetails.ReasoningTokens,
					CacheReadTokens: promptTokenDetails.CachedTokens,
				}

				// Add prediction token counts if available.
				if completionTokenDetails.AcceptedPredictionTokens > 0 {
					streamProviderMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
				}
				if completionTokenDetails.RejectedPredictionTokens > 0 {
					streamProviderMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
				}
			}
			if len(chunk.Choices) == 0 {
				continue
			}
			for _, choice := range chunk.Choices {
				switch {
				case choice.Delta.Content != "":
					if !isActiveText {
						isActiveText = true
						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeTextStart,
							ID:   "0",
						}) {
							return
						}
					}
					if !yield(ai.StreamPart{
						Type:  ai.StreamPartTypeTextDelta,
						ID:    "0",
						Delta: choice.Delta.Content,
					}) {
						return
					}
				case len(choice.Delta.ToolCalls) > 0:
					if isActiveText {
						isActiveText = false
						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeTextEnd,
							ID:   "0",
						}) {
							return
						}
					}

					for _, toolCallDelta := range choice.Delta.ToolCalls {
						if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
							if existingToolCall.hasFinished {
								continue
							}
							if toolCallDelta.Function.Arguments != "" {
								existingToolCall.arguments += toolCallDelta.Function.Arguments
							}
							if !yield(ai.StreamPart{
								Type:  ai.StreamPartTypeToolInputDelta,
								ID:    existingToolCall.id,
								Delta: toolCallDelta.Function.Arguments,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = existingToolCall
							if isValidJSON(existingToolCall.arguments) {
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeToolInputEnd,
									ID:   existingToolCall.id,
								}) {
									return
								}

								if !yield(ai.StreamPart{
									Type:          ai.StreamPartTypeToolCall,
									ID:            existingToolCall.id,
									ToolCallName:  existingToolCall.name,
									ToolCallInput: existingToolCall.arguments,
								}) {
									return
								}
								existingToolCall.hasFinished = true
								toolCalls[toolCallDelta.Index] = existingToolCall
							}
						} else {
							// First delta for this tool call: validate it before tracking.
							var err error
							if toolCallDelta.Type != "function" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
							}
							if toolCallDelta.ID == "" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
							}
							if toolCallDelta.Function.Name == "" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
							}
							if err != nil {
								yield(ai.StreamPart{
									Type:  ai.StreamPartTypeError,
									Error: o.handleError(err),
								})
								return
							}

							if !yield(ai.StreamPart{
								Type:         ai.StreamPartTypeToolInputStart,
								ID:           toolCallDelta.ID,
								ToolCallName: toolCallDelta.Function.Name,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = toolCall{
								id:        toolCallDelta.ID,
								name:      toolCallDelta.Function.Name,
								arguments: toolCallDelta.Function.Arguments,
							}

							exTc := toolCalls[toolCallDelta.Index]
							if exTc.arguments != "" {
								if !yield(ai.StreamPart{
									Type:  ai.StreamPartTypeToolInputDelta,
									ID:    exTc.id,
									Delta: exTc.arguments,
								}) {
									return
								}
								if isValidJSON(exTc.arguments) {
									if !yield(ai.StreamPart{
										Type: ai.StreamPartTypeToolInputEnd,
										ID:   exTc.id,
									}) {
										return
									}

									if !yield(ai.StreamPart{
										Type:          ai.StreamPartTypeToolCall,
										ID:            exTc.id,
										ToolCallName:  exTc.name,
										ToolCallInput: exTc.arguments,
									}) {
										return
									}
									exTc.hasFinished = true
									toolCalls[toolCallDelta.Index] = exTc
								}
							}
						}
					}
				}
			}

			// Check for annotations in the delta's raw JSON.
			for _, choice := range chunk.Choices {
				for _, annotation := range parseAnnotationsFromDelta(choice.Delta) {
					if annotation.Type == "url_citation" {
						if !yield(ai.StreamPart{
							Type:       ai.StreamPartTypeSource,
							ID:         uuid.NewString(),
							SourceType: ai.SourceTypeURL,
							URL:        annotation.URLCitation.URL,
							Title:      annotation.URLCitation.Title,
						}) {
							return
						}
					}
				}
			}
		}
		err := stream.Err()
		if err != nil && !errors.Is(err, io.EOF) {
			yield(ai.StreamPart{
				Type:  ai.StreamPartTypeError,
				Error: o.handleError(err),
			})
			return
		}

		// Stream finished: close any open text block, emit accumulated
		// sources and metadata, then the finish part.
		if isActiveText {
			isActiveText = false
			if !yield(ai.StreamPart{
				Type: ai.StreamPartTypeTextEnd,
				ID:   "0",
			}) {
				return
			}
		}

		// Add logprobs if available.
		if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
			streamProviderMetadata["openai"]["logprobs"] = acc.Choices[0].Logprobs.Content
		}

		// Handle annotations/citations from the accumulated response.
		if len(acc.Choices) > 0 {
			for _, annotation := range acc.Choices[0].Message.Annotations {
				if annotation.Type == "url_citation" {
					if !yield(ai.StreamPart{
						Type:       ai.StreamPartTypeSource,
						ID:         acc.ID,
						SourceType: ai.SourceTypeURL,
						URL:        annotation.URLCitation.URL,
						Title:      annotation.URLCitation.Title,
					}) {
						return
					}
				}
			}
		}

		finishReason := ai.FinishReasonUnknown
		if len(acc.Choices) > 0 {
			finishReason = mapOpenAiFinishReason(acc.Choices[0].FinishReason)
		}
		yield(ai.StreamPart{
			Type:             ai.StreamPartTypeFinish,
			Usage:            usage,
			FinishReason:     finishReason,
			ProviderMetadata: streamProviderMetadata,
		})
	}, nil
}
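// mapOpenAiFinishReason maps an OpenAI finish reason onto the corresponding
// ai.FinishReason.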
func mapOpenAiFinishReason(finishReason string) ai.FinishReason {
	switch finishReason {
	case "stop":
		return ai.FinishReasonStop
	case "length":
		return ai.FinishReasonLength
	case "content_filter":
		return ai.FinishReasonContentFilter
	case "function_call", "tool_calls":
		return ai.FinishReasonToolCalls
	default:
		return ai.FinishReasonUnknown
	}
}
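// isReasoningModel reports whether the model is a reasoning model
// (o-series or gpt-5, excluding the conversational gpt-5-chat variants).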
func isReasoningModel(modelID string) bool {
	// gpt-5-chat models are conversational, not reasoning models, so they
	// are excluded despite sharing the gpt-5 prefix.
	return strings.HasPrefix(modelID, "o") ||
		(strings.HasPrefix(modelID, "gpt-5") && !strings.HasPrefix(modelID, "gpt-5-chat"))
}
func isSearchPreviewModel(modelID string) bool {
	return strings.Contains(modelID, "search-preview")
}
func supportsFlexProcessing(modelID string) bool {
	return strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini") || strings.HasPrefix(modelID, "gpt-5")
}
func supportsPriorityProcessing(modelID string) bool {
	// gpt-5 covers gpt-5-mini; gpt-5-nano is excluded, matching the warning
	// emitted in prepareParams.
	return strings.HasPrefix(modelID, "gpt-4") ||
		(strings.HasPrefix(modelID, "gpt-5") && !strings.HasPrefix(modelID, "gpt-5-nano")) ||
		strings.HasPrefix(modelID, "o3") ||
		strings.HasPrefix(modelID, "o4-mini")
}
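// toOpenAiTools converts tool definitions and the tool choice into their
// OpenAI request equivalents, warning about unsupported tool types.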
func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				continue
			}
			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
				OfFunction: &openai.ChatCompletionFunctionToolParam{
					Function: shared.FunctionDefinitionParam{
						Name:        ft.Name,
						Description: param.NewOpt(ft.Description),
						Parameters:  openai.FunctionParameters(ft.InputSchema),
						Strict:      param.NewOpt(false),
					},
					Type: "function",
				},
			})
			continue
		}

		// TODO: handle provider tool calls
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return openAiTools, openAiToolChoice, warnings
	}

	switch *toolChoice {
	case ai.ToolChoiceAuto:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("auto"),
		}
	case ai.ToolChoiceNone:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("none"),
		}
	default:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
				Type: "function",
				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
					Name: string(*toolChoice),
				},
			},
		}
	}
	return openAiTools, openAiToolChoice, warnings
}
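// toPrompt converts an ai.Prompt into OpenAI chat messages, collecting
// warnings for content it cannot represent.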
func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
	var messages []openai.ChatCompletionMessageParamUnion
	var warnings []ai.CallWarning
	for _, msg := range prompt {
		switch msg.Role {
		case ai.MessageRoleSystem:
			var systemPromptParts []string
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeText {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt can only have text content",
					})
					continue
				}
				textPart, ok := ai.AsContentType[ai.TextPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt text part does not have the right type",
					})
					continue
				}
				if strings.TrimSpace(textPart.Text) != "" {
					systemPromptParts = append(systemPromptParts, textPart.Text)
				}
			}
			if len(systemPromptParts) == 0 {
				warnings = append(warnings, ai.CallWarning{
					Type:    ai.CallWarningTypeOther,
					Message: "system prompt has no text parts",
				})
				continue
			}
			messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
		case ai.MessageRoleUser:
			// A user message with a single text part maps to a plain string message.
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "user message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.UserMessage(textPart.Text))
				continue
			}
			// Text content plus attachments. For now only a fixed set of
			// media types is supported.
			// TODO: add the supported media types to the language model so we
			// can use that to validate the data here.
			var content []openai.ChatCompletionContentPartUnionParam
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message text part does not have the right type",
						})
						continue
					}
					content = append(content, openai.ChatCompletionContentPartUnionParam{
						OfText: &openai.ChatCompletionContentPartTextParam{
							Text: textPart.Text,
						},
					})
				case ai.ContentTypeFile:
					filePart, ok := ai.AsContentType[ai.FilePart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message file part does not have the right type",
						})
						continue
					}

					switch {
					case strings.HasPrefix(filePart.MediaType, "image/"):
						// Handle image files.
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
						imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}

						// Check for provider-specific options like image detail.
						if providerOptions, ok := filePart.ProviderOptions["openai"]; ok {
							if detail, ok := providerOptions["imageDetail"].(string); ok {
								imageURL.Detail = detail
							}
						}

						imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})

					case filePart.MediaType == "audio/wav":
						// Handle WAV audio files.
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "wav",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
						// Handle MP3 audio files.
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "mp3",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "application/pdf":
						// Handle PDF files.
						dataStr := string(filePart.Data)

						// Data that looks like a file ID (starts with "file-")
						// is passed by reference instead of being inlined.
						if strings.HasPrefix(dataStr, "file-") {
							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									FileID: param.NewOpt(dataStr),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						} else {
							// Handle as base64 data.
							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
							data := "data:application/pdf;base64," + base64Encoded

							filename := filePart.Filename
							if filename == "" {
								// Generate a default filename based on the content index.
								filename = fmt.Sprintf("part-%d.pdf", len(content))
							}

							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									Filename: param.NewOpt(filename),
									FileData: param.NewOpt(data),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						}

					default:
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
						})
					}
				}
			}
			messages = append(messages, openai.UserMessage(content))
		case ai.MessageRoleAssistant:
			// An assistant message with a single text part maps to a plain string message.
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "assistant message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.AssistantMessage(textPart.Text))
				continue
			}
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message text part does not have the right type",
						})
						continue
					}
					assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
						OfString: param.NewOpt(textPart.Text),
					}
				case ai.ContentTypeToolCall:
					toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message tool part does not have the right type",
						})
						continue
					}
					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
						openai.ChatCompletionMessageToolCallUnionParam{
							OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
								ID:   toolCallPart.ToolCallID,
								Type: "function",
								Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
									Name:      toolCallPart.ToolName,
									Arguments: toolCallPart.Input,
								},
							},
						})
				}
			}
			messages = append(messages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})
		case ai.MessageRoleTool:
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeToolResult {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message can only have tool result content",
					})
					continue
				}

				toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message result part does not have the right type",
					})
					continue
				}

				switch toolResultPart.Output.GetType() {
				case ai.ToolResultContentTypeText:
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
				case ai.ToolResultContentTypeError:
					// TODO: check if better handling is needed
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
				}
			}
		}
	}
	return messages, warnings
}
// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
	var annotations []openai.ChatCompletionMessageAnnotation

	// Parse the raw JSON to extract annotations.
	var deltaData map[string]any
	if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
		return annotations
	}

	// Check if annotations exist in the delta.
	annotationsData, ok := deltaData["annotations"].([]any)
	if !ok {
		return annotations
	}
	for _, annotationData := range annotationsData {
		annotationMap, ok := annotationData.(map[string]any)
		if !ok {
			continue
		}
		if annotationType, ok := annotationMap["type"].(string); !ok || annotationType != "url_citation" {
			continue
		}
		urlCitationData, ok := annotationMap["url_citation"].(map[string]any)
		if !ok {
			continue
		}
		// Use checked assertions so malformed payloads cannot panic.
		url, _ := urlCitationData["url"].(string)
		title, _ := urlCitationData["title"].(string)
		annotations = append(annotations, openai.ChatCompletionMessageAnnotation{
			Type: "url_citation",
			URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
				URL:   url,
				Title: title,
			},
		})
	}

	return annotations
}