openai.go

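// Package openai implements ai.Provider on top of the OpenAI Chat
// Completions API.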
package openai

import (
	"cmp"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"maps"
	"strings"

	"github.com/charmbracelet/ai/ai"
	xjson "github.com/charmbracelet/x/json"
	"github.com/google/uuid"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/option"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/openai/openai-go/v2/shared"
)

type provider struct {
	options options
}

type options struct {
	baseURL      string
	apiKey       string
	organization string
	project      string
	name         string
	headers      map[string]string
	client       option.HTTPClient
}

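// Option configures the provider returned by New.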
type Option = func(*options)

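// New builds an OpenAI provider. Options left unset fall back to defaults:
// https://api.openai.com/v1 as the base URL and "openai" as the provider
// name. A minimal (illustrative) usage sketch:
//
//	p := New(WithAPIKey(os.Getenv("OPENAI_API_KEY")))
//	model, _ := p.LanguageModel("gpt-4o")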
func New(opts ...Option) ai.Provider {
	options := options{
		headers: map[string]string{},
	}
	for _, o := range opts {
		o(&options)
	}

	options.baseURL = cmp.Or(options.baseURL, "https://api.openai.com/v1")
	options.name = cmp.Or(options.name, "openai")

	if options.organization != "" {
		options.headers["OpenAI-Organization"] = options.organization
	}
	if options.project != "" {
		options.headers["OpenAI-Project"] = options.project
	}

	return &provider{options: options}
}

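// WithBaseURL overrides the default API base URL (https://api.openai.com/v1).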
func WithBaseURL(baseURL string) Option {
	return func(o *options) {
		o.baseURL = baseURL
	}
}

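// WithAPIKey sets the API key used to authenticate requests.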
func WithAPIKey(apiKey string) Option {
	return func(o *options) {
		o.apiKey = apiKey
	}
}

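// WithOrganization sets the organization sent in the OpenAI-Organization header.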
func WithOrganization(organization string) Option {
	return func(o *options) {
		o.organization = organization
	}
}

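// WithProject sets the project sent in the OpenAI-Project header.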
func WithProject(project string) Option {
	return func(o *options) {
		o.project = project
	}
}

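// WithName overrides the provider name used in model metadata (default "openai").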
func WithName(name string) Option {
	return func(o *options) {
		o.name = name
	}
}

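// WithHeaders adds extra HTTP headers to every request.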
func WithHeaders(headers map[string]string) Option {
	return func(o *options) {
		maps.Copy(o.headers, headers)
	}
}

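// WithHTTPClient sets a custom HTTP client for API requests.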
func WithHTTPClient(client option.HTTPClient) Option {
	return func(o *options) {
		o.client = client
	}
}

// LanguageModel implements ai.Provider.
func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
	openaiClientOptions := []option.RequestOption{}
	if o.options.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
	}
	if o.options.baseURL != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
	}

	for key, value := range o.options.headers {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	if o.options.client != nil {
		openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
	}

	return languageModel{
		modelID:  modelID,
		provider: fmt.Sprintf("%s.chat", o.options.name),
		options:  o.options,
		client:   openai.NewClient(openaiClientOptions...),
	}, nil
}

type languageModel struct {
	provider string
	modelID  string
	client   openai.Client
	options  options
}

// Model implements ai.LanguageModel.
func (o languageModel) Model() string {
	return o.modelID
}

// Provider implements ai.LanguageModel.
func (o languageModel) Provider() string {
	return o.provider
}

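// prepareParams translates an ai.Call into ChatCompletionNewParams, applying
// OpenAI-specific provider options and collecting warnings for settings the
// target model does not support.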
func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
	params := &openai.ChatCompletionNewParams{}
	messages, warnings := toPrompt(call.Prompt)
	providerOptions := &providerOptions{}
	if v, ok := call.ProviderOptions["openai"]; ok {
		err := ai.ParseOptions(v, providerOptions)
		if err != nil {
			return nil, nil, err
		}
	}
	if call.TopK != nil {
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedSetting,
			Setting: "top_k",
		})
	}
	params.Messages = messages
	params.Model = o.modelID
	if providerOptions.LogitBias != nil {
		params.LogitBias = providerOptions.LogitBias
	}
	if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
		providerOptions.LogProbs = nil
	}
	if providerOptions.LogProbs != nil {
		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
	}
	if providerOptions.TopLogProbs != nil {
		params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
	}
	if providerOptions.User != nil {
		params.User = param.NewOpt(*providerOptions.User)
	}
	if providerOptions.ParallelToolCalls != nil {
		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
	}

	if call.MaxOutputTokens != nil {
		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
	}
	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}
	if call.FrequencyPenalty != nil {
		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
	}
	if call.PresencePenalty != nil {
		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
	}

	if providerOptions.MaxCompletionTokens != nil {
		params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
	}

	if providerOptions.TextVerbosity != nil {
		params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
	}
	if providerOptions.Prediction != nil {
		// Convert map[string]any to ChatCompletionPredictionContentParam.
		if content, ok := providerOptions.Prediction["content"]; ok {
			if contentStr, ok := content.(string); ok {
				params.Prediction = openai.ChatCompletionPredictionContentParam{
					Content: openai.ChatCompletionPredictionContentContentUnionParam{
						OfString: param.NewOpt(contentStr),
					},
				}
			}
		}
	}
	if providerOptions.Store != nil {
		params.Store = param.NewOpt(*providerOptions.Store)
	}
	if providerOptions.Metadata != nil {
		// Convert map[string]any to map[string]string.
		metadata := make(map[string]string)
		for k, v := range providerOptions.Metadata {
			if str, ok := v.(string); ok {
				metadata[k] = str
			}
		}
		params.Metadata = metadata
	}
	if providerOptions.PromptCacheKey != nil {
		params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
	}
	if providerOptions.SafetyIdentifier != nil {
		params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
	}
	if providerOptions.ServiceTier != nil {
		params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
	}

	if providerOptions.ReasoningEffort != nil {
		switch *providerOptions.ReasoningEffort {
		case reasoningEffortMinimal:
			params.ReasoningEffort = shared.ReasoningEffortMinimal
		case reasoningEffortLow:
			params.ReasoningEffort = shared.ReasoningEffortLow
		case reasoningEffortMedium:
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case reasoningEffortHigh:
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			return nil, nil, fmt.Errorf("reasoning effort %q is not supported", *providerOptions.ReasoningEffort)
		}
	}

	if isReasoningModel(o.modelID) {
		// Remove unsupported settings for reasoning models.
		// See https://platform.openai.com/docs/guides/reasoning#limitations
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for reasoning models",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopP",
				Details: "TopP is not supported for reasoning models",
			})
		}
		if call.FrequencyPenalty != nil {
			params.FrequencyPenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "FrequencyPenalty",
				Details: "FrequencyPenalty is not supported for reasoning models",
			})
		}
		if call.PresencePenalty != nil {
			params.PresencePenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "PresencePenalty",
				Details: "PresencePenalty is not supported for reasoning models",
			})
		}
		if providerOptions.LogitBias != nil {
			params.LogitBias = nil
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "LogitBias",
				Message: "LogitBias is not supported for reasoning models",
			})
		}
		if providerOptions.LogProbs != nil {
			params.Logprobs = param.Opt[bool]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "Logprobs",
				Message: "Logprobs is not supported for reasoning models",
			})
		}
		if providerOptions.TopLogProbs != nil {
			params.TopLogprobs = param.Opt[int64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopLogprobs",
				Message: "TopLogprobs is not supported for reasoning models",
			})
		}

		// Reasoning models use max_completion_tokens instead of max_tokens.
		if call.MaxOutputTokens != nil {
			if providerOptions.MaxCompletionTokens == nil {
				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
			}
			params.MaxTokens = param.Opt[int64]{}
		}
	}

	// Handle search preview models.
	if isSearchPreviewModel(o.modelID) {
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for the search preview models and has been removed.",
			})
		}
	}

	// Handle service tier validation.
	if providerOptions.ServiceTier != nil {
		serviceTier := *providerOptions.ServiceTier
		if serviceTier == "flex" && !supportsFlexProcessing(o.modelID) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "ServiceTier",
				Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
			})
		} else if serviceTier == "priority" && !supportsPriorityProcessing(o.modelID) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "ServiceTier",
				Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported",
			})
		}
	}

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}
	return params, warnings, nil
}

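// handleError converts OpenAI API errors into ai.APICallError values carrying
// the status code, response headers, and request/response dumps; all other
// errors are returned unchanged.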
func (o languageModel) handleError(err error) error {
	var apiErr *openai.Error
	if errors.As(err, &apiErr) {
		requestDump := apiErr.DumpRequest(true)
		responseDump := apiErr.DumpResponse(true)
		headers := map[string]string{}
		for k, h := range apiErr.Response.Header {
			v := h[len(h)-1]
			headers[strings.ToLower(k)] = v
		}
		return ai.NewAPICallError(
			apiErr.Message,
			apiErr.Request.URL.String(),
			string(requestDump),
			apiErr.StatusCode,
			headers,
			string(responseDump),
			apiErr,
			false,
		)
	}
	return err
}

// Generate implements ai.LanguageModel.
func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, o.handleError(err)
	}

	if len(response.Choices) == 0 {
		return nil, errors.New("no response generated")
	}
	choice := response.Choices[0]
	content := make([]ai.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
	text := choice.Message.Content
	if text != "" {
		content = append(content, ai.TextContent{
			Text: text,
		})
	}

	for _, tc := range choice.Message.ToolCalls {
		toolCallID := tc.ID
		if toolCallID == "" {
			toolCallID = uuid.NewString()
		}
		content = append(content, ai.ToolCallContent{
			ProviderExecuted: false, // TODO: update when handling other tools
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Handle annotations/citations.
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, ai.SourceContent{
				SourceType: ai.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	completionTokenDetails := response.Usage.CompletionTokensDetails
	promptTokenDetails := response.Usage.PromptTokensDetails

	// Build provider metadata.
	providerMetadata := ai.ProviderMetadata{
		"openai": make(map[string]any),
	}

	// Add logprobs if available.
	if len(choice.Logprobs.Content) > 0 {
		providerMetadata["openai"]["logprobs"] = choice.Logprobs.Content
	}

	// Add prediction tokens if available.
	if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
		if completionTokenDetails.AcceptedPredictionTokens > 0 {
			providerMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
		}
		if completionTokenDetails.RejectedPredictionTokens > 0 {
			providerMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
		}
	}

	return &ai.Response{
		Content: content,
		Usage: ai.Usage{
			InputTokens:     response.Usage.PromptTokens,
			OutputTokens:    response.Usage.CompletionTokens,
			TotalTokens:     response.Usage.TotalTokens,
			ReasoningTokens: completionTokenDetails.ReasoningTokens,
			CacheReadTokens: promptTokenDetails.CachedTokens,
		},
		FinishReason:     mapOpenAiFinishReason(choice.FinishReason),
		ProviderMetadata: providerMetadata,
		Warnings:         warnings,
	}, nil
}

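// toolCall tracks the incremental state of a single tool call while streaming.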
type toolCall struct {
	id          string
	name        string
	arguments   string
	hasFinished bool
}

// Stream implements ai.LanguageModel.
func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}

	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
	isActiveText := false
	toolCalls := make(map[int64]toolCall)

	// Build provider metadata for streaming.
	streamProviderMetadata := ai.ProviderMetadata{
		"openai": make(map[string]any),
	}

	acc := openai.ChatCompletionAccumulator{}
	var usage ai.Usage
	return func(yield func(ai.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(ai.StreamPart{
				Type:     ai.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}
		for stream.Next() {
			chunk := stream.Current()
			acc.AddChunk(chunk)
			if chunk.Usage.TotalTokens > 0 {
				// We do this here because the accumulator does not add prompt details.
				completionTokenDetails := chunk.Usage.CompletionTokensDetails
				promptTokenDetails := chunk.Usage.PromptTokensDetails
				usage = ai.Usage{
					InputTokens:     chunk.Usage.PromptTokens,
					OutputTokens:    chunk.Usage.CompletionTokens,
					TotalTokens:     chunk.Usage.TotalTokens,
					ReasoningTokens: completionTokenDetails.ReasoningTokens,
					CacheReadTokens: promptTokenDetails.CachedTokens,
				}

				// Add prediction tokens if available.
				if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
					if completionTokenDetails.AcceptedPredictionTokens > 0 {
						streamProviderMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
					}
					if completionTokenDetails.RejectedPredictionTokens > 0 {
						streamProviderMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
					}
				}
			}
			if len(chunk.Choices) == 0 {
				continue
			}
			for _, choice := range chunk.Choices {
				switch {
				case choice.Delta.Content != "":
					if !isActiveText {
						isActiveText = true
						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeTextStart,
							ID:   "0",
						}) {
							return
						}
					}
					if !yield(ai.StreamPart{
						Type:  ai.StreamPartTypeTextDelta,
						ID:    "0",
						Delta: choice.Delta.Content,
					}) {
						return
					}
				case len(choice.Delta.ToolCalls) > 0:
					if isActiveText {
						isActiveText = false
						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeTextEnd,
							ID:   "0",
						}) {
							return
						}
					}

					for _, toolCallDelta := range choice.Delta.ToolCalls {
						if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
							if existingToolCall.hasFinished {
								continue
							}
							if toolCallDelta.Function.Arguments != "" {
								existingToolCall.arguments += toolCallDelta.Function.Arguments
							}
							if !yield(ai.StreamPart{
								Type:  ai.StreamPartTypeToolInputDelta,
								ID:    existingToolCall.id,
								Delta: toolCallDelta.Function.Arguments,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = existingToolCall
							if xjson.IsValid(existingToolCall.arguments) {
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeToolInputEnd,
									ID:   existingToolCall.id,
								}) {
									return
								}

								if !yield(ai.StreamPart{
									Type:          ai.StreamPartTypeToolCall,
									ID:            existingToolCall.id,
									ToolCallName:  existingToolCall.name,
									ToolCallInput: existingToolCall.arguments,
								}) {
									return
								}
								existingToolCall.hasFinished = true
								toolCalls[toolCallDelta.Index] = existingToolCall
							}
						} else {
							// First delta for this tool call index.
							var err error
							if toolCallDelta.Type != "function" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
							}
							if toolCallDelta.ID == "" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
							}
							if toolCallDelta.Function.Name == "" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
							}
							if err != nil {
								yield(ai.StreamPart{
									Type:  ai.StreamPartTypeError,
									Error: err,
								})
								return
							}

							if !yield(ai.StreamPart{
								Type:         ai.StreamPartTypeToolInputStart,
								ID:           toolCallDelta.ID,
								ToolCallName: toolCallDelta.Function.Name,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = toolCall{
								id:        toolCallDelta.ID,
								name:      toolCallDelta.Function.Name,
								arguments: toolCallDelta.Function.Arguments,
							}

							exTc := toolCalls[toolCallDelta.Index]
							if exTc.arguments != "" {
								if !yield(ai.StreamPart{
									Type:  ai.StreamPartTypeToolInputDelta,
									ID:    exTc.id,
									Delta: exTc.arguments,
								}) {
									return
								}
								if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
									if !yield(ai.StreamPart{
										Type: ai.StreamPartTypeToolInputEnd,
										ID:   toolCallDelta.ID,
									}) {
										return
									}

									if !yield(ai.StreamPart{
										Type:          ai.StreamPartTypeToolCall,
										ID:            exTc.id,
										ToolCallName:  exTc.name,
										ToolCallInput: exTc.arguments,
									}) {
										return
									}
									exTc.hasFinished = true
									toolCalls[toolCallDelta.Index] = exTc
								}
							}
							continue
						}
					}
				}
			}

			// Check for annotations in the delta's raw JSON.
			for _, choice := range chunk.Choices {
				if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
					for _, annotation := range annotations {
						if annotation.Type == "url_citation" {
							if !yield(ai.StreamPart{
								Type:       ai.StreamPartTypeSource,
								ID:         uuid.NewString(),
								SourceType: ai.SourceTypeURL,
								URL:        annotation.URLCitation.URL,
								Title:      annotation.URLCitation.Title,
							}) {
								return
							}
						}
					}
				}
			}
		}
		err := stream.Err()
		if err == nil || errors.Is(err, io.EOF) {
			// Finished.
			if isActiveText {
				isActiveText = false
				if !yield(ai.StreamPart{
					Type: ai.StreamPartTypeTextEnd,
					ID:   "0",
				}) {
					return
				}
			}

			// Add logprobs if available.
			if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
				streamProviderMetadata["openai"]["logprobs"] = acc.Choices[0].Logprobs.Content
			}

			// Handle annotations/citations from the accumulated response.
			if len(acc.Choices) > 0 {
				for _, annotation := range acc.Choices[0].Message.Annotations {
					if annotation.Type == "url_citation" {
						if !yield(ai.StreamPart{
							Type:       ai.StreamPartTypeSource,
							ID:         acc.ID,
							SourceType: ai.SourceTypeURL,
							URL:        annotation.URLCitation.URL,
							Title:      annotation.URLCitation.Title,
						}) {
							return
						}
					}
				}
			}

			finishReason := ai.FinishReasonUnknown
			if len(acc.Choices) > 0 {
				finishReason = mapOpenAiFinishReason(acc.Choices[0].FinishReason)
			}
			yield(ai.StreamPart{
				Type:             ai.StreamPartTypeFinish,
				Usage:            usage,
				FinishReason:     finishReason,
				ProviderMetadata: streamProviderMetadata,
			})
			return
		} else {
			yield(ai.StreamPart{
				Type:  ai.StreamPartTypeError,
				Error: o.handleError(err),
			})
			return
		}
	}, nil
}

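// mapOpenAiFinishReason maps an OpenAI finish_reason string onto the ai
// package's FinishReason constants.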
func mapOpenAiFinishReason(finishReason string) ai.FinishReason {
	switch finishReason {
	case "stop":
		return ai.FinishReasonStop
	case "length":
		return ai.FinishReasonLength
	case "content_filter":
		return ai.FinishReasonContentFilter
	case "function_call", "tool_calls":
		return ai.FinishReasonToolCalls
	default:
		return ai.FinishReasonUnknown
	}
}

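// isReasoningModel reports whether the model belongs to OpenAI's reasoning
// family (o-series and gpt-5 models).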
func isReasoningModel(modelID string) bool {
	return strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5") || strings.HasPrefix(modelID, "gpt-5-chat")
}

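// isSearchPreviewModel reports whether the model is a search-preview variant.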
func isSearchPreviewModel(modelID string) bool {
	return strings.Contains(modelID, "search-preview")
}

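// supportsFlexProcessing reports whether the model accepts the "flex" service tier.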
func supportsFlexProcessing(modelID string) bool {
	return strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini") || strings.HasPrefix(modelID, "gpt-5")
}

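// supportsPriorityProcessing reports whether the model accepts the "priority" service tier.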
func supportsPriorityProcessing(modelID string) bool {
	return strings.HasPrefix(modelID, "gpt-4") || strings.HasPrefix(modelID, "gpt-5") ||
		strings.HasPrefix(modelID, "gpt-5-mini") || strings.HasPrefix(modelID, "o3") ||
		strings.HasPrefix(modelID, "o4-mini")
}

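// toOpenAiTools converts tools and the tool choice into their OpenAI request
// equivalents, warning about tool types that are not supported.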
func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				continue
			}
			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
				OfFunction: &openai.ChatCompletionFunctionToolParam{
					Function: shared.FunctionDefinitionParam{
						Name:        ft.Name,
						Description: param.NewOpt(ft.Description),
						Parameters:  openai.FunctionParameters(ft.InputSchema),
						Strict:      param.NewOpt(false),
					},
					Type: "function",
				},
			})
			continue
		}

		// TODO: handle provider tool calls
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return openAiTools, openAiToolChoice, warnings
	}

	switch *toolChoice {
	case ai.ToolChoiceAuto:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("auto"),
		}
	case ai.ToolChoiceNone:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("none"),
		}
	default:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
				Type: "function",
				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
					Name: string(*toolChoice),
				},
			},
		}
	}
	return openAiTools, openAiToolChoice, warnings
}

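// toPrompt converts an ai.Prompt into OpenAI chat messages: system prompts are
// flattened to text, user attachments are encoded as content parts, and
// assistant tool calls and tool results are mapped to their message types.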
func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
	var messages []openai.ChatCompletionMessageParamUnion
	var warnings []ai.CallWarning
	for _, msg := range prompt {
		switch msg.Role {
		case ai.MessageRoleSystem:
			var systemPromptParts []string
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeText {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt can only have text content",
					})
					continue
				}
				textPart, ok := ai.AsContentType[ai.TextPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt text part does not have the right type",
					})
					continue
				}
				text := textPart.Text
				if strings.TrimSpace(text) != "" {
					systemPromptParts = append(systemPromptParts, textPart.Text)
				}
			}
			if len(systemPromptParts) == 0 {
				warnings = append(warnings, ai.CallWarning{
					Type:    ai.CallWarningTypeOther,
					Message: "system prompt has no text parts",
				})
				continue
			}
			messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
		case ai.MessageRoleUser:
			// A simple user message with only text content.
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "user message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.UserMessage(textPart.Text))
				continue
			}
			// Text content plus attachments. For now only a fixed set of media
			// types is supported.
			// TODO: add the supported media types to the language model so we
			// can use that to validate the data here.
			var content []openai.ChatCompletionContentPartUnionParam
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message text part does not have the right type",
						})
						continue
					}
					content = append(content, openai.ChatCompletionContentPartUnionParam{
						OfText: &openai.ChatCompletionContentPartTextParam{
							Text: textPart.Text,
						},
					})
				case ai.ContentTypeFile:
					filePart, ok := ai.AsContentType[ai.FilePart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message file part does not have the right type",
						})
						continue
					}

					switch {
					case strings.HasPrefix(filePart.MediaType, "image/"):
						// Handle image files.
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
						imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}

						// Check for provider-specific options like image detail.
						if providerOptions, ok := filePart.ProviderOptions["openai"]; ok {
							if detail, ok := providerOptions["imageDetail"].(string); ok {
								imageURL.Detail = detail
							}
						}

						imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})

					case filePart.MediaType == "audio/wav":
						// Handle WAV audio files.
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "wav",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
						// Handle MP3 audio files.
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "mp3",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "application/pdf":
						// Handle PDF files.
						dataStr := string(filePart.Data)

						// Check if the data looks like a file ID (starts with "file-").
						if strings.HasPrefix(dataStr, "file-") {
							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									FileID: param.NewOpt(dataStr),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						} else {
							// Handle as base64 data.
							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
							data := "data:application/pdf;base64," + base64Encoded

							filename := filePart.Filename
							if filename == "" {
								// Generate a default filename based on the content index.
								filename = fmt.Sprintf("part-%d.pdf", len(content))
							}

							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									Filename: param.NewOpt(filename),
									FileData: param.NewOpt(data),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						}

					default:
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
						})
					}
				}
			}
			messages = append(messages, openai.UserMessage(content))
		case ai.MessageRoleAssistant:
			// A simple assistant message with only text content.
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "assistant message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.AssistantMessage(textPart.Text))
				continue
			}
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message text part does not have the right type",
						})
						continue
					}
					assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
						OfString: param.NewOpt(textPart.Text),
					}
				case ai.ContentTypeToolCall:
					toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message tool part does not have the right type",
						})
						continue
					}
					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
						openai.ChatCompletionMessageToolCallUnionParam{
							OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
								ID:   toolCallPart.ToolCallID,
								Type: "function",
								Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
									Name:      toolCallPart.ToolName,
									Arguments: toolCallPart.Input,
								},
							},
						})
				}
			}
			messages = append(messages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})
		case ai.MessageRoleTool:
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeToolResult {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message can only have tool result content",
					})
					continue
				}

				toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message result part does not have the right type",
					})
					continue
				}

				switch toolResultPart.Output.GetType() {
				case ai.ToolResultContentTypeText:
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
				case ai.ToolResultContentTypeError:
					// TODO: check if better handling is needed.
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
				}
			}
		}
	}
	return messages, warnings
}

// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
	var annotations []openai.ChatCompletionMessageAnnotation

	// Parse the raw JSON to extract annotations.
	var deltaData map[string]any
	if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
		return annotations
	}

	// Check if annotations exist in the delta.
	if annotationsData, ok := deltaData["annotations"].([]any); ok {
		for _, annotationData := range annotationsData {
			if annotationMap, ok := annotationData.(map[string]any); ok {
				if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
					if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
						// Use checked assertions so malformed payloads cannot panic.
						url, _ := urlCitationData["url"].(string)
						title, _ := urlCitationData["title"].(string)
						annotations = append(annotations, openai.ChatCompletionMessageAnnotation{
							Type: "url_citation",
							URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
								URL:   url,
								Title: title,
							},
						})
					}
				}
			}
		}
	}

	return annotations
}