1package openai
2
3import (
4 "cmp"
5 "context"
6 "encoding/base64"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "io"
11 "maps"
12 "strings"
13
14 "github.com/charmbracelet/ai/ai"
15 xjson "github.com/charmbracelet/x/json"
16 "github.com/google/uuid"
17 "github.com/openai/openai-go/v2"
18 "github.com/openai/openai-go/v2/option"
19 "github.com/openai/openai-go/v2/packages/param"
20 "github.com/openai/openai-go/v2/shared"
21)
22
// provider is the OpenAI implementation of ai.Provider. It holds the
// fully-resolved configuration used to construct language models.
type provider struct {
	options options
}
26
// options holds the configurable settings for the OpenAI provider.
// Unset fields are filled with defaults by New.
type options struct {
	baseURL      string            // API base URL; defaults to https://api.openai.com/v1
	apiKey       string            // bearer token used to authenticate requests
	organization string            // forwarded as the OpenAi-Organization header when set
	project      string            // forwarded as the OpenAi-Project header when set
	name         string            // provider label used in metadata; defaults to "openai"
	headers      map[string]string // extra headers applied to every request
	client       option.HTTPClient // optional custom HTTP client
}
36
37type Option = func(*options)
38
39func New(opts ...Option) ai.Provider {
40 options := options{
41 headers: map[string]string{},
42 }
43 for _, o := range opts {
44 o(&options)
45 }
46
47 options.baseURL = cmp.Or(options.baseURL, "https://api.openai.com/v1")
48 options.name = cmp.Or(options.name, "openai")
49
50 if options.organization != "" {
51 options.headers["OpenAi-Organization"] = options.organization
52 }
53 if options.project != "" {
54 options.headers["OpenAi-Project"] = options.project
55 }
56
57 return &provider{options: options}
58}
59
60func WithBaseURL(baseURL string) Option {
61 return func(o *options) {
62 o.baseURL = baseURL
63 }
64}
65
66func WithAPIKey(apiKey string) Option {
67 return func(o *options) {
68 o.apiKey = apiKey
69 }
70}
71
72func WithOrganization(organization string) Option {
73 return func(o *options) {
74 o.organization = organization
75 }
76}
77
78func WithProject(project string) Option {
79 return func(o *options) {
80 o.project = project
81 }
82}
83
84func WithName(name string) Option {
85 return func(o *options) {
86 o.name = name
87 }
88}
89
90func WithHeaders(headers map[string]string) Option {
91 return func(o *options) {
92 maps.Copy(o.headers, headers)
93 }
94}
95
96func WithHTTPClient(client option.HTTPClient) Option {
97 return func(o *options) {
98 o.client = client
99 }
100}
101
102// LanguageModel implements ai.Provider.
103func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
104 openaiClientOptions := []option.RequestOption{}
105 if o.options.apiKey != "" {
106 openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
107 }
108 if o.options.baseURL != "" {
109 openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
110 }
111
112 for key, value := range o.options.headers {
113 openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
114 }
115
116 if o.options.client != nil {
117 openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
118 }
119
120 return languageModel{
121 modelID: modelID,
122 provider: fmt.Sprintf("%s.chat", o.options.name),
123 options: o.options,
124 client: openai.NewClient(openaiClientOptions...),
125 }, nil
126}
127
// languageModel is the chat-completions implementation of ai.LanguageModel.
type languageModel struct {
	provider string        // provider label, e.g. "openai.chat"
	modelID  string        // OpenAI model identifier
	client   openai.Client // configured OpenAI SDK client
	options  options       // provider-level configuration
}
134
// Model implements ai.LanguageModel. It returns the model identifier this
// instance was created with.
func (o languageModel) Model() string {
	return o.modelID
}
139
// Provider implements ai.LanguageModel. It returns the provider label
// (e.g. "openai.chat").
func (o languageModel) Provider() string {
	return o.provider
}
144
145func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
146 params := &openai.ChatCompletionNewParams{}
147 messages, warnings := toPrompt(call.Prompt)
148 providerOptions := &ProviderOptions{}
149 if v, ok := call.ProviderOptions["openai"]; ok {
150 providerOptions, ok = v.(*ProviderOptions)
151 if !ok {
152 return nil, nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
153 }
154 }
155 if call.TopK != nil {
156 warnings = append(warnings, ai.CallWarning{
157 Type: ai.CallWarningTypeUnsupportedSetting,
158 Setting: "top_k",
159 })
160 }
161 params.Messages = messages
162 params.Model = o.modelID
163 if providerOptions.LogitBias != nil {
164 params.LogitBias = providerOptions.LogitBias
165 }
166 if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
167 providerOptions.LogProbs = nil
168 }
169 if providerOptions.LogProbs != nil {
170 params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
171 }
172 if providerOptions.TopLogProbs != nil {
173 params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
174 }
175 if providerOptions.User != nil {
176 params.User = param.NewOpt(*providerOptions.User)
177 }
178 if providerOptions.ParallelToolCalls != nil {
179 params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
180 }
181
182 if call.MaxOutputTokens != nil {
183 params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
184 }
185 if call.Temperature != nil {
186 params.Temperature = param.NewOpt(*call.Temperature)
187 }
188 if call.TopP != nil {
189 params.TopP = param.NewOpt(*call.TopP)
190 }
191 if call.FrequencyPenalty != nil {
192 params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
193 }
194 if call.PresencePenalty != nil {
195 params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
196 }
197
198 if providerOptions.MaxCompletionTokens != nil {
199 params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
200 }
201
202 if providerOptions.TextVerbosity != nil {
203 params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
204 }
205 if providerOptions.Prediction != nil {
206 // Convert map[string]any to ChatCompletionPredictionContentParam
207 if content, ok := providerOptions.Prediction["content"]; ok {
208 if contentStr, ok := content.(string); ok {
209 params.Prediction = openai.ChatCompletionPredictionContentParam{
210 Content: openai.ChatCompletionPredictionContentContentUnionParam{
211 OfString: param.NewOpt(contentStr),
212 },
213 }
214 }
215 }
216 }
217 if providerOptions.Store != nil {
218 params.Store = param.NewOpt(*providerOptions.Store)
219 }
220 if providerOptions.Metadata != nil {
221 // Convert map[string]any to map[string]string
222 metadata := make(map[string]string)
223 for k, v := range providerOptions.Metadata {
224 if str, ok := v.(string); ok {
225 metadata[k] = str
226 }
227 }
228 params.Metadata = metadata
229 }
230 if providerOptions.PromptCacheKey != nil {
231 params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
232 }
233 if providerOptions.SafetyIdentifier != nil {
234 params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
235 }
236 if providerOptions.ServiceTier != nil {
237 params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
238 }
239
240 if providerOptions.ReasoningEffort != nil {
241 switch *providerOptions.ReasoningEffort {
242 case ReasoningEffortMinimal:
243 params.ReasoningEffort = shared.ReasoningEffortMinimal
244 case ReasoningEffortLow:
245 params.ReasoningEffort = shared.ReasoningEffortLow
246 case ReasoningEffortMedium:
247 params.ReasoningEffort = shared.ReasoningEffortMedium
248 case ReasoningEffortHigh:
249 params.ReasoningEffort = shared.ReasoningEffortHigh
250 default:
251 return nil, nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
252 }
253 }
254
255 if isReasoningModel(o.modelID) {
256 // remove unsupported settings for reasoning models
257 // see https://platform.openai.com/docs/guides/reasoning#limitations
258 if call.Temperature != nil {
259 params.Temperature = param.Opt[float64]{}
260 warnings = append(warnings, ai.CallWarning{
261 Type: ai.CallWarningTypeUnsupportedSetting,
262 Setting: "temperature",
263 Details: "temperature is not supported for reasoning models",
264 })
265 }
266 if call.TopP != nil {
267 params.TopP = param.Opt[float64]{}
268 warnings = append(warnings, ai.CallWarning{
269 Type: ai.CallWarningTypeUnsupportedSetting,
270 Setting: "TopP",
271 Details: "TopP is not supported for reasoning models",
272 })
273 }
274 if call.FrequencyPenalty != nil {
275 params.FrequencyPenalty = param.Opt[float64]{}
276 warnings = append(warnings, ai.CallWarning{
277 Type: ai.CallWarningTypeUnsupportedSetting,
278 Setting: "FrequencyPenalty",
279 Details: "FrequencyPenalty is not supported for reasoning models",
280 })
281 }
282 if call.PresencePenalty != nil {
283 params.PresencePenalty = param.Opt[float64]{}
284 warnings = append(warnings, ai.CallWarning{
285 Type: ai.CallWarningTypeUnsupportedSetting,
286 Setting: "PresencePenalty",
287 Details: "PresencePenalty is not supported for reasoning models",
288 })
289 }
290 if providerOptions.LogitBias != nil {
291 params.LogitBias = nil
292 warnings = append(warnings, ai.CallWarning{
293 Type: ai.CallWarningTypeUnsupportedSetting,
294 Setting: "LogitBias",
295 Message: "LogitBias is not supported for reasoning models",
296 })
297 }
298 if providerOptions.LogProbs != nil {
299 params.Logprobs = param.Opt[bool]{}
300 warnings = append(warnings, ai.CallWarning{
301 Type: ai.CallWarningTypeUnsupportedSetting,
302 Setting: "Logprobs",
303 Message: "Logprobs is not supported for reasoning models",
304 })
305 }
306 if providerOptions.TopLogProbs != nil {
307 params.TopLogprobs = param.Opt[int64]{}
308 warnings = append(warnings, ai.CallWarning{
309 Type: ai.CallWarningTypeUnsupportedSetting,
310 Setting: "TopLogprobs",
311 Message: "TopLogprobs is not supported for reasoning models",
312 })
313 }
314
315 // reasoning models use max_completion_tokens instead of max_tokens
316 if call.MaxOutputTokens != nil {
317 if providerOptions.MaxCompletionTokens == nil {
318 params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
319 }
320 params.MaxTokens = param.Opt[int64]{}
321 }
322 }
323
324 // Handle search preview models
325 if isSearchPreviewModel(o.modelID) {
326 if call.Temperature != nil {
327 params.Temperature = param.Opt[float64]{}
328 warnings = append(warnings, ai.CallWarning{
329 Type: ai.CallWarningTypeUnsupportedSetting,
330 Setting: "temperature",
331 Details: "temperature is not supported for the search preview models and has been removed.",
332 })
333 }
334 }
335
336 // Handle service tier validation
337 if providerOptions.ServiceTier != nil {
338 serviceTier := *providerOptions.ServiceTier
339 if serviceTier == "flex" && !supportsFlexProcessing(o.modelID) {
340 params.ServiceTier = ""
341 warnings = append(warnings, ai.CallWarning{
342 Type: ai.CallWarningTypeUnsupportedSetting,
343 Setting: "ServiceTier",
344 Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
345 })
346 } else if serviceTier == "priority" && !supportsPriorityProcessing(o.modelID) {
347 params.ServiceTier = ""
348 warnings = append(warnings, ai.CallWarning{
349 Type: ai.CallWarningTypeUnsupportedSetting,
350 Setting: "ServiceTier",
351 Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported",
352 })
353 }
354 }
355
356 if len(call.Tools) > 0 {
357 tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
358 params.Tools = tools
359 if toolChoice != nil {
360 params.ToolChoice = *toolChoice
361 }
362 warnings = append(warnings, toolWarnings...)
363 }
364 return params, warnings, nil
365}
366
367func (o languageModel) handleError(err error) error {
368 var apiErr *openai.Error
369 if errors.As(err, &apiErr) {
370 requestDump := apiErr.DumpRequest(true)
371 responseDump := apiErr.DumpResponse(true)
372 headers := map[string]string{}
373 for k, h := range apiErr.Response.Header {
374 v := h[len(h)-1]
375 headers[strings.ToLower(k)] = v
376 }
377 return ai.NewAPICallError(
378 apiErr.Message,
379 apiErr.Request.URL.String(),
380 string(requestDump),
381 apiErr.StatusCode,
382 headers,
383 string(responseDump),
384 apiErr,
385 false,
386 )
387 }
388 return err
389}
390
391// Generate implements ai.LanguageModel.
// Generate implements ai.LanguageModel. It performs a single non-streaming
// chat completion call and converts the first choice of the response into
// ai.Response content (text, tool calls, and url_citation sources).
func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, o.handleError(err)
	}

	if len(response.Choices) == 0 {
		return nil, errors.New("no response generated")
	}
	// Only the first choice is used; n > 1 is never requested.
	choice := response.Choices[0]
	content := make([]ai.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
	text := choice.Message.Content
	if text != "" {
		content = append(content, ai.TextContent{
			Text: text,
		})
	}

	for _, tc := range choice.Message.ToolCalls {
		// Some responses may omit the tool call ID; generate one so callers
		// can always correlate tool results.
		toolCallID := tc.ID
		if toolCallID == "" {
			toolCallID = uuid.NewString()
		}
		content = append(content, ai.ToolCallContent{
			ProviderExecuted: false, // TODO: update when handling other tools
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Handle annotations/citations: only url_citation is mapped.
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, ai.SourceContent{
				SourceType: ai.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	completionTokenDetails := response.Usage.CompletionTokensDetails
	promptTokenDetails := response.Usage.PromptTokensDetails

	// Build provider metadata.
	providerMetadata := &ProviderMetadata{}
	// Add logprobs if available.
	if len(choice.Logprobs.Content) > 0 {
		providerMetadata.Logprobs = choice.Logprobs.Content
	}

	// Add prediction tokens if available.
	if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
		if completionTokenDetails.AcceptedPredictionTokens > 0 {
			providerMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
		}
		if completionTokenDetails.RejectedPredictionTokens > 0 {
			providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
		}
	}

	return &ai.Response{
		Content: content,
		Usage: ai.Usage{
			InputTokens:     response.Usage.PromptTokens,
			OutputTokens:    response.Usage.CompletionTokens,
			TotalTokens:     response.Usage.TotalTokens,
			ReasoningTokens: completionTokenDetails.ReasoningTokens,
			CacheReadTokens: promptTokenDetails.CachedTokens,
		},
		FinishReason: mapOpenAiFinishReason(choice.FinishReason),
		ProviderMetadata: ai.ProviderMetadata{
			"openai": providerMetadata,
		},
		Warnings: warnings,
	}, nil
}
474
// toolCall tracks the incremental state of a streamed tool call while its
// JSON arguments are still being accumulated across deltas.
type toolCall struct {
	id          string // tool call ID from the first delta
	name        string // function name from the first delta
	arguments   string // concatenated JSON argument fragments
	hasFinished bool   // set once the complete tool-call part has been emitted
}
481
482// Stream implements ai.LanguageModel.
483func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
484 params, warnings, err := o.prepareParams(call)
485 if err != nil {
486 return nil, err
487 }
488
489 params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
490 IncludeUsage: openai.Bool(true),
491 }
492
493 stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
494 isActiveText := false
495 toolCalls := make(map[int64]toolCall)
496
497 // Build provider metadata for streaming
498 streamProviderMetadata := &ProviderMetadata{}
499 acc := openai.ChatCompletionAccumulator{}
500 var usage ai.Usage
501 return func(yield func(ai.StreamPart) bool) {
502 if len(warnings) > 0 {
503 if !yield(ai.StreamPart{
504 Type: ai.StreamPartTypeWarnings,
505 Warnings: warnings,
506 }) {
507 return
508 }
509 }
510 for stream.Next() {
511 chunk := stream.Current()
512 acc.AddChunk(chunk)
513 if chunk.Usage.TotalTokens > 0 {
514 // we do this here because the acc does not add prompt details
515 completionTokenDetails := chunk.Usage.CompletionTokensDetails
516 promptTokenDetails := chunk.Usage.PromptTokensDetails
517 usage = ai.Usage{
518 InputTokens: chunk.Usage.PromptTokens,
519 OutputTokens: chunk.Usage.CompletionTokens,
520 TotalTokens: chunk.Usage.TotalTokens,
521 ReasoningTokens: completionTokenDetails.ReasoningTokens,
522 CacheReadTokens: promptTokenDetails.CachedTokens,
523 }
524
525 // Add prediction tokens if available
526 if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
527 if completionTokenDetails.AcceptedPredictionTokens > 0 {
528 streamProviderMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
529 }
530 if completionTokenDetails.RejectedPredictionTokens > 0 {
531 streamProviderMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
532 }
533 }
534 }
535 if len(chunk.Choices) == 0 {
536 continue
537 }
538 for _, choice := range chunk.Choices {
539 switch {
540 case choice.Delta.Content != "":
541 if !isActiveText {
542 isActiveText = true
543 if !yield(ai.StreamPart{
544 Type: ai.StreamPartTypeTextStart,
545 ID: "0",
546 }) {
547 return
548 }
549 }
550 if !yield(ai.StreamPart{
551 Type: ai.StreamPartTypeTextDelta,
552 ID: "0",
553 Delta: choice.Delta.Content,
554 }) {
555 return
556 }
557 case len(choice.Delta.ToolCalls) > 0:
558 if isActiveText {
559 isActiveText = false
560 if !yield(ai.StreamPart{
561 Type: ai.StreamPartTypeTextEnd,
562 ID: "0",
563 }) {
564 return
565 }
566 }
567
568 for _, toolCallDelta := range choice.Delta.ToolCalls {
569 if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
570 if existingToolCall.hasFinished {
571 continue
572 }
573 if toolCallDelta.Function.Arguments != "" {
574 existingToolCall.arguments += toolCallDelta.Function.Arguments
575 }
576 if !yield(ai.StreamPart{
577 Type: ai.StreamPartTypeToolInputDelta,
578 ID: existingToolCall.id,
579 Delta: toolCallDelta.Function.Arguments,
580 }) {
581 return
582 }
583 toolCalls[toolCallDelta.Index] = existingToolCall
584 if xjson.IsValid(existingToolCall.arguments) {
585 if !yield(ai.StreamPart{
586 Type: ai.StreamPartTypeToolInputEnd,
587 ID: existingToolCall.id,
588 }) {
589 return
590 }
591
592 if !yield(ai.StreamPart{
593 Type: ai.StreamPartTypeToolCall,
594 ID: existingToolCall.id,
595 ToolCallName: existingToolCall.name,
596 ToolCallInput: existingToolCall.arguments,
597 }) {
598 return
599 }
600 existingToolCall.hasFinished = true
601 toolCalls[toolCallDelta.Index] = existingToolCall
602 }
603 } else {
604 // Does not exist
605 var err error
606 if toolCallDelta.Type != "function" {
607 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
608 }
609 if toolCallDelta.ID == "" {
610 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
611 }
612 if toolCallDelta.Function.Name == "" {
613 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
614 }
615 if err != nil {
616 yield(ai.StreamPart{
617 Type: ai.StreamPartTypeError,
618 Error: o.handleError(stream.Err()),
619 })
620 return
621 }
622
623 if !yield(ai.StreamPart{
624 Type: ai.StreamPartTypeToolInputStart,
625 ID: toolCallDelta.ID,
626 ToolCallName: toolCallDelta.Function.Name,
627 }) {
628 return
629 }
630 toolCalls[toolCallDelta.Index] = toolCall{
631 id: toolCallDelta.ID,
632 name: toolCallDelta.Function.Name,
633 arguments: toolCallDelta.Function.Arguments,
634 }
635
636 exTc := toolCalls[toolCallDelta.Index]
637 if exTc.arguments != "" {
638 if !yield(ai.StreamPart{
639 Type: ai.StreamPartTypeToolInputDelta,
640 ID: exTc.id,
641 Delta: exTc.arguments,
642 }) {
643 return
644 }
645 if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
646 if !yield(ai.StreamPart{
647 Type: ai.StreamPartTypeToolInputEnd,
648 ID: toolCallDelta.ID,
649 }) {
650 return
651 }
652
653 if !yield(ai.StreamPart{
654 Type: ai.StreamPartTypeToolCall,
655 ID: exTc.id,
656 ToolCallName: exTc.name,
657 ToolCallInput: exTc.arguments,
658 }) {
659 return
660 }
661 exTc.hasFinished = true
662 toolCalls[toolCallDelta.Index] = exTc
663 }
664 }
665 continue
666 }
667 }
668 }
669 }
670
671 // Check for annotations in the delta's raw JSON
672 for _, choice := range chunk.Choices {
673 if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
674 for _, annotation := range annotations {
675 if annotation.Type == "url_citation" {
676 if !yield(ai.StreamPart{
677 Type: ai.StreamPartTypeSource,
678 ID: uuid.NewString(),
679 SourceType: ai.SourceTypeURL,
680 URL: annotation.URLCitation.URL,
681 Title: annotation.URLCitation.Title,
682 }) {
683 return
684 }
685 }
686 }
687 }
688 }
689 }
690 err := stream.Err()
691 if err == nil || errors.Is(err, io.EOF) {
692 // finished
693 if isActiveText {
694 isActiveText = false
695 if !yield(ai.StreamPart{
696 Type: ai.StreamPartTypeTextEnd,
697 ID: "0",
698 }) {
699 return
700 }
701 }
702
703 // Add logprobs if available
704 if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
705 streamProviderMetadata.Logprobs = acc.Choices[0].Logprobs.Content
706 }
707
708 // Handle annotations/citations from accumulated response
709 if len(acc.Choices) > 0 {
710 for _, annotation := range acc.Choices[0].Message.Annotations {
711 if annotation.Type == "url_citation" {
712 if !yield(ai.StreamPart{
713 Type: ai.StreamPartTypeSource,
714 ID: acc.ID,
715 SourceType: ai.SourceTypeURL,
716 URL: annotation.URLCitation.URL,
717 Title: annotation.URLCitation.Title,
718 }) {
719 return
720 }
721 }
722 }
723 }
724
725 finishReason := mapOpenAiFinishReason(acc.Choices[0].FinishReason)
726 yield(ai.StreamPart{
727 Type: ai.StreamPartTypeFinish,
728 Usage: usage,
729 FinishReason: finishReason,
730 ProviderMetadata: ai.ProviderMetadata{
731 "openai": streamProviderMetadata,
732 },
733 })
734 return
735 } else {
736 yield(ai.StreamPart{
737 Type: ai.StreamPartTypeError,
738 Error: o.handleError(err),
739 })
740 return
741 }
742 }, nil
743}
744
745func mapOpenAiFinishReason(finishReason string) ai.FinishReason {
746 switch finishReason {
747 case "stop":
748 return ai.FinishReasonStop
749 case "length":
750 return ai.FinishReasonLength
751 case "content_filter":
752 return ai.FinishReasonContentFilter
753 case "function_call", "tool_calls":
754 return ai.FinishReasonToolCalls
755 default:
756 return ai.FinishReasonUnknown
757 }
758}
759
// isReasoningModel reports whether the model uses the reasoning-model
// parameter restrictions (o-series and gpt-5 families).
func isReasoningModel(modelID string) bool {
	// The "gpt-5" prefix already covers "gpt-5-chat", so the previous extra
	// check for it was dead code. NOTE(review): the bare "o" prefix matches
	// ANY model ID starting with the letter o, not just o1/o3/o4 — confirm
	// this breadth is intended before tightening it.
	return strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5")
}
763
// isSearchPreviewModel reports whether the model is one of the
// "search-preview" variants, which reject the temperature setting.
func isSearchPreviewModel(modelID string) bool {
	return strings.Contains(modelID, "search-preview")
}
767
// supportsFlexProcessing reports whether the model supports the "flex"
// service tier (the o3, o4-mini, and gpt-5 model families).
func supportsFlexProcessing(modelID string) bool {
	for _, prefix := range []string{"o3", "o4-mini", "gpt-5"} {
		if strings.HasPrefix(modelID, prefix) {
			return true
		}
	}
	return false
}
771
// supportsPriorityProcessing reports whether the model supports the
// "priority" service tier: gpt-4, gpt-5, gpt-5-mini, o3, and o4-mini.
// gpt-5-nano is explicitly excluded, matching the warning emitted in
// prepareParams ("gpt-5-nano is not supported").
func supportsPriorityProcessing(modelID string) bool {
	// Excluded even though it shares the gpt-5 prefix.
	if strings.HasPrefix(modelID, "gpt-5-nano") {
		return false
	}
	// The gpt-5 prefix already covers gpt-5-mini.
	return strings.HasPrefix(modelID, "gpt-4") ||
		strings.HasPrefix(modelID, "gpt-5") ||
		strings.HasPrefix(modelID, "o3") ||
		strings.HasPrefix(modelID, "o4-mini")
}
777
778func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
779 for _, tool := range tools {
780 if tool.GetType() == ai.ToolTypeFunction {
781 ft, ok := tool.(ai.FunctionTool)
782 if !ok {
783 continue
784 }
785 openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
786 OfFunction: &openai.ChatCompletionFunctionToolParam{
787 Function: shared.FunctionDefinitionParam{
788 Name: ft.Name,
789 Description: param.NewOpt(ft.Description),
790 Parameters: openai.FunctionParameters(ft.InputSchema),
791 Strict: param.NewOpt(false),
792 },
793 Type: "function",
794 },
795 })
796 continue
797 }
798
799 // TODO: handle provider tool calls
800 warnings = append(warnings, ai.CallWarning{
801 Type: ai.CallWarningTypeUnsupportedTool,
802 Tool: tool,
803 Message: "tool is not supported",
804 })
805 }
806 if toolChoice == nil {
807 return openAiTools, openAiToolChoice, warnings
808 }
809
810 switch *toolChoice {
811 case ai.ToolChoiceAuto:
812 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
813 OfAuto: param.NewOpt("auto"),
814 }
815 case ai.ToolChoiceNone:
816 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
817 OfAuto: param.NewOpt("none"),
818 }
819 default:
820 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
821 OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
822 Type: "function",
823 Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
824 Name: string(*toolChoice),
825 },
826 },
827 }
828 }
829 return openAiTools, openAiToolChoice, warnings
830}
831
// toPrompt converts the generic prompt messages into OpenAI chat-completion
// message params. Content that cannot be represented (unsupported media
// types, mismatched part types, non-text system content) is dropped and
// reported via the returned warnings rather than failing the call.
func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
	var messages []openai.ChatCompletionMessageParamUnion
	var warnings []ai.CallWarning
	for _, msg := range prompt {
		switch msg.Role {
		case ai.MessageRoleSystem:
			// System messages: concatenate all non-empty text parts into a
			// single system message, one line per part.
			var systemPromptParts []string
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeText {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt can only have text content",
					})
					continue
				}
				textPart, ok := ai.AsContentType[ai.TextPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt text part does not have the right type",
					})
					continue
				}
				text := textPart.Text
				if strings.TrimSpace(text) != "" {
					systemPromptParts = append(systemPromptParts, textPart.Text)
				}
			}
			if len(systemPromptParts) == 0 {
				warnings = append(warnings, ai.CallWarning{
					Type:    ai.CallWarningTypeOther,
					Message: "system prompt has no text parts",
				})
				continue
			}
			messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
		case ai.MessageRoleUser:
			// simple user message just text content
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "user message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.UserMessage(textPart.Text))
				continue
			}
			// text content and attachments
			// for now we only support image content later we need to check
			// TODO: add the supported media types to the language model so we
			// can use that to validate the data here.
			var content []openai.ChatCompletionContentPartUnionParam
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message text part does not have the right type",
						})
						continue
					}
					content = append(content, openai.ChatCompletionContentPartUnionParam{
						OfText: &openai.ChatCompletionContentPartTextParam{
							Text: textPart.Text,
						},
					})
				case ai.ContentTypeFile:
					filePart, ok := ai.AsContentType[ai.FilePart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message file part does not have the right type",
						})
						continue
					}

					switch {
					case strings.HasPrefix(filePart.MediaType, "image/"):
						// Handle image files: embedded as a base64 data URL.
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
						imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}

						// Check for provider-specific options like image detail
						if providerOptions, ok := filePart.ProviderOptions["openai"]; ok {
							if detail, ok := providerOptions.(*ProviderFileOptions); ok {
								imageURL.Detail = detail.ImageDetail
							}
						}

						imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})

					case filePart.MediaType == "audio/wav":
						// Handle WAV audio files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "wav",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
						// Handle MP3 audio files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "mp3",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "application/pdf":
						// Handle PDF files
						dataStr := string(filePart.Data)

						// Check if data looks like a file ID (starts with "file-")
						if strings.HasPrefix(dataStr, "file-") {
							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									FileID: param.NewOpt(dataStr),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						} else {
							// Handle as base64 data
							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
							data := "data:application/pdf;base64," + base64Encoded

							filename := filePart.Filename
							if filename == "" {
								// Generate default filename based on content index
								filename = fmt.Sprintf("part-%d.pdf", len(content))
							}

							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									Filename: param.NewOpt(filename),
									FileData: param.NewOpt(data),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						}

					default:
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
						})
					}
				}
			}
			// NOTE(review): this appends a user message even when every part
			// was dropped and content is empty — confirm that is intended.
			messages = append(messages, openai.UserMessage(content))
		case ai.MessageRoleAssistant:
			// simple assistant message just text content
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "assistant message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.AssistantMessage(textPart.Text))
				continue
			}
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message text part does not have the right type",
						})
						continue
					}
					// NOTE(review): Content is overwritten on each text part,
					// so only the last text part survives when a message has
					// several — confirm whether they should be joined instead.
					assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
						OfString: param.NewOpt(textPart.Text),
					}
				case ai.ContentTypeToolCall:
					toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message tool part does not have the right type",
						})
						continue
					}
					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
						openai.ChatCompletionMessageToolCallUnionParam{
							OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
								ID:   toolCallPart.ToolCallID,
								Type: "function",
								Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
									Name:      toolCallPart.ToolName,
									Arguments: toolCallPart.Input,
								},
							},
						})
				}
			}
			messages = append(messages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})
		case ai.MessageRoleTool:
			// Tool messages: one OpenAI tool message per tool result part.
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeToolResult {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message can only have tool result content",
					})
					continue
				}

				toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message result part does not have the right type",
					})
					continue
				}

				switch toolResultPart.Output.GetType() {
				case ai.ToolResultContentTypeText:
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
				case ai.ToolResultContentTypeError:
					// TODO: check if better handling is needed
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
				}
			}
		}
	}
	return messages, warnings
}
1096
1097// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
1098func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
1099 var annotations []openai.ChatCompletionMessageAnnotation
1100
1101 // Parse the raw JSON to extract annotations
1102 var deltaData map[string]any
1103 if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
1104 return annotations
1105 }
1106
1107 // Check if annotations exist in the delta
1108 if annotationsData, ok := deltaData["annotations"].([]any); ok {
1109 for _, annotationData := range annotationsData {
1110 if annotationMap, ok := annotationData.(map[string]any); ok {
1111 if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
1112 if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
1113 annotation := openai.ChatCompletionMessageAnnotation{
1114 Type: "url_citation",
1115 URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
1116 URL: urlCitationData["url"].(string),
1117 Title: urlCitationData["title"].(string),
1118 },
1119 }
1120 annotations = append(annotations, annotation)
1121 }
1122 }
1123 }
1124 }
1125 }
1126
1127 return annotations
1128}