openai.go

   1package openai
   2
   3import (
   4	"cmp"
   5	"context"
   6	"encoding/base64"
   7	"encoding/json"
   8	"errors"
   9	"fmt"
  10	"io"
  11	"maps"
  12	"strings"
  13
  14	"github.com/charmbracelet/fantasy/ai"
  15	xjson "github.com/charmbracelet/x/json"
  16	"github.com/google/uuid"
  17	"github.com/openai/openai-go/v2"
  18	"github.com/openai/openai-go/v2/option"
  19	"github.com/openai/openai-go/v2/packages/param"
  20	"github.com/openai/openai-go/v2/shared"
  21)
  22
// Name is the canonical identifier of this provider; DefaultURL is the
// OpenAI API endpoint used when no custom base URL is configured.
const (
	Name       = "openai"
	DefaultURL = "https://api.openai.com/v1"
)

// provider is the OpenAI implementation of ai.Provider. It holds the
// configuration assembled by New and the With* option functions.
type provider struct {
	options options
}

// PrepareCallWithOptions is the signature of the hook that maps
// provider-specific call options onto the outgoing chat-completion params,
// returning any warnings produced along the way.
type PrepareCallWithOptions = func(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error)

// Hooks bundles optional callbacks that override parts of the default
// request-preparation pipeline.
type Hooks struct {
	// PrepareCallWithOptions, when non-nil, replaces the built-in
	// provider-options handling used by prepareParams.
	PrepareCallWithOptions PrepareCallWithOptions
}

// options carries the provider configuration populated via Option functions.
type options struct {
	baseURL      string            // API base URL; defaults to DefaultURL
	apiKey       string            // bearer token used for authentication
	organization string            // sent as the OpenAI organization header
	project      string            // sent as the OpenAI project header
	name         string            // provider name; defaults to Name
	hooks        Hooks             // optional request-preparation overrides
	headers      map[string]string // extra headers added to every request
	client       option.HTTPClient // custom HTTP client, if any
}

// Option mutates the provider configuration; pass these to New.
type Option = func(*options)
  50
  51func New(opts ...Option) ai.Provider {
  52	providerOptions := options{
  53		headers: map[string]string{},
  54	}
  55	for _, o := range opts {
  56		o(&providerOptions)
  57	}
  58
  59	providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
  60	providerOptions.name = cmp.Or(providerOptions.name, Name)
  61
  62	if providerOptions.organization != "" {
  63		providerOptions.headers["OpenAi-Organization"] = providerOptions.organization
  64	}
  65	if providerOptions.project != "" {
  66		providerOptions.headers["OpenAi-Project"] = providerOptions.project
  67	}
  68
  69	return &provider{options: providerOptions}
  70}
  71
// WithBaseURL overrides the API base URL (default DefaultURL).
func WithBaseURL(baseURL string) Option {
	return func(o *options) {
		o.baseURL = baseURL
	}
}

// WithAPIKey sets the API key used to authenticate requests.
func WithAPIKey(apiKey string) Option {
	return func(o *options) {
		o.apiKey = apiKey
	}
}

// WithOrganization sets the OpenAI organization ID sent with each request.
func WithOrganization(organization string) Option {
	return func(o *options) {
		o.organization = organization
	}
}

// WithProject sets the OpenAI project ID sent with each request.
func WithProject(project string) Option {
	return func(o *options) {
		o.project = project
	}
}

// WithName overrides the provider name reported to callers (default Name).
func WithName(name string) Option {
	return func(o *options) {
		o.name = name
	}
}

// WithHeaders merges the given headers into every outgoing request,
// overwriting any previously configured entries with the same key.
func WithHeaders(headers map[string]string) Option {
	return func(o *options) {
		maps.Copy(o.headers, headers)
	}
}

// WithHTTPClient supplies a custom HTTP client for API calls.
func WithHTTPClient(client option.HTTPClient) Option {
	return func(o *options) {
		o.client = client
	}
}

// WithHooks installs request-preparation hooks (see Hooks).
func WithHooks(hooks Hooks) Option {
	return func(o *options) {
		o.hooks = hooks
	}
}
 119
 120// LanguageModel implements ai.Provider.
 121func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
 122	openaiClientOptions := []option.RequestOption{}
 123	if o.options.apiKey != "" {
 124		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
 125	}
 126	if o.options.baseURL != "" {
 127		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
 128	}
 129
 130	for key, value := range o.options.headers {
 131		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
 132	}
 133
 134	if o.options.client != nil {
 135		openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
 136	}
 137
 138	return languageModel{
 139		modelID:  modelID,
 140		provider: o.options.name,
 141		options:  o.options,
 142		client:   openai.NewClient(openaiClientOptions...),
 143	}, nil
 144}
 145
// languageModel is the chat-completions implementation of ai.LanguageModel
// backed by an OpenAI SDK client.
type languageModel struct {
	provider string
	modelID  string
	client   openai.Client
	options  options
}

// Model implements ai.LanguageModel.
func (o languageModel) Model() string {
	return o.modelID
}

// Provider implements ai.LanguageModel.
func (o languageModel) Provider() string {
	return o.provider
}
 162
 163func prepareCallWithOptions(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
 164	if call.ProviderOptions == nil {
 165		return nil, nil
 166	}
 167	var warnings []ai.CallWarning
 168	providerOptions := &ProviderOptions{}
 169	if v, ok := call.ProviderOptions[Name]; ok {
 170		providerOptions, ok = v.(*ProviderOptions)
 171		if !ok {
 172			return nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
 173		}
 174	}
 175
 176	if providerOptions.LogitBias != nil {
 177		params.LogitBias = providerOptions.LogitBias
 178	}
 179	if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
 180		providerOptions.LogProbs = nil
 181	}
 182	if providerOptions.LogProbs != nil {
 183		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
 184	}
 185	if providerOptions.TopLogProbs != nil {
 186		params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
 187	}
 188	if providerOptions.User != nil {
 189		params.User = param.NewOpt(*providerOptions.User)
 190	}
 191	if providerOptions.ParallelToolCalls != nil {
 192		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
 193	}
 194	if providerOptions.MaxCompletionTokens != nil {
 195		params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
 196	}
 197
 198	if providerOptions.TextVerbosity != nil {
 199		params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
 200	}
 201	if providerOptions.Prediction != nil {
 202		// Convert map[string]any to ChatCompletionPredictionContentParam
 203		if content, ok := providerOptions.Prediction["content"]; ok {
 204			if contentStr, ok := content.(string); ok {
 205				params.Prediction = openai.ChatCompletionPredictionContentParam{
 206					Content: openai.ChatCompletionPredictionContentContentUnionParam{
 207						OfString: param.NewOpt(contentStr),
 208					},
 209				}
 210			}
 211		}
 212	}
 213	if providerOptions.Store != nil {
 214		params.Store = param.NewOpt(*providerOptions.Store)
 215	}
 216	if providerOptions.Metadata != nil {
 217		// Convert map[string]any to map[string]string
 218		metadata := make(map[string]string)
 219		for k, v := range providerOptions.Metadata {
 220			if str, ok := v.(string); ok {
 221				metadata[k] = str
 222			}
 223		}
 224		params.Metadata = metadata
 225	}
 226	if providerOptions.PromptCacheKey != nil {
 227		params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
 228	}
 229	if providerOptions.SafetyIdentifier != nil {
 230		params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
 231	}
 232	if providerOptions.ServiceTier != nil {
 233		params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
 234	}
 235
 236	if providerOptions.ReasoningEffort != nil {
 237		switch *providerOptions.ReasoningEffort {
 238		case ReasoningEffortMinimal:
 239			params.ReasoningEffort = shared.ReasoningEffortMinimal
 240		case ReasoningEffortLow:
 241			params.ReasoningEffort = shared.ReasoningEffortLow
 242		case ReasoningEffortMedium:
 243			params.ReasoningEffort = shared.ReasoningEffortMedium
 244		case ReasoningEffortHigh:
 245			params.ReasoningEffort = shared.ReasoningEffortHigh
 246		default:
 247			return nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
 248		}
 249	}
 250
 251	if isReasoningModel(model.Model()) {
 252		if providerOptions.LogitBias != nil {
 253			params.LogitBias = nil
 254			warnings = append(warnings, ai.CallWarning{
 255				Type:    ai.CallWarningTypeUnsupportedSetting,
 256				Setting: "LogitBias",
 257				Message: "LogitBias is not supported for reasoning models",
 258			})
 259		}
 260		if providerOptions.LogProbs != nil {
 261			params.Logprobs = param.Opt[bool]{}
 262			warnings = append(warnings, ai.CallWarning{
 263				Type:    ai.CallWarningTypeUnsupportedSetting,
 264				Setting: "Logprobs",
 265				Message: "Logprobs is not supported for reasoning models",
 266			})
 267		}
 268		if providerOptions.TopLogProbs != nil {
 269			params.TopLogprobs = param.Opt[int64]{}
 270			warnings = append(warnings, ai.CallWarning{
 271				Type:    ai.CallWarningTypeUnsupportedSetting,
 272				Setting: "TopLogprobs",
 273				Message: "TopLogprobs is not supported for reasoning models",
 274			})
 275		}
 276	}
 277
 278	// Handle service tier validation
 279	if providerOptions.ServiceTier != nil {
 280		serviceTier := *providerOptions.ServiceTier
 281		if serviceTier == "flex" && !supportsFlexProcessing(model.Model()) {
 282			params.ServiceTier = ""
 283			warnings = append(warnings, ai.CallWarning{
 284				Type:    ai.CallWarningTypeUnsupportedSetting,
 285				Setting: "ServiceTier",
 286				Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
 287			})
 288		} else if serviceTier == "priority" && !supportsPriorityProcessing(model.Model()) {
 289			params.ServiceTier = ""
 290			warnings = append(warnings, ai.CallWarning{
 291				Type:    ai.CallWarningTypeUnsupportedSetting,
 292				Setting: "ServiceTier",
 293				Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported",
 294			})
 295		}
 296	}
 297	return warnings, nil
 298}
 299
// prepareParams translates a provider-agnostic ai.Call into OpenAI
// chat-completion params. It converts the prompt, copies sampling settings,
// strips settings unsupported by reasoning and search-preview models, runs
// the provider-options hook, and attaches tools. All dropped settings are
// reported as warnings; a non-nil error aborts the call.
func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
	params := &openai.ChatCompletionNewParams{}
	messages, warnings := toPrompt(call.Prompt)
	// The chat-completions API has no top_k parameter.
	if call.TopK != nil {
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedSetting,
			Setting: "top_k",
		})
	}
	params.Messages = messages
	params.Model = o.modelID

	if call.MaxOutputTokens != nil {
		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
	}
	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}
	if call.FrequencyPenalty != nil {
		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
	}
	if call.PresencePenalty != nil {
		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
	}

	if isReasoningModel(o.modelID) {
		// remove unsupported settings for reasoning models
		// see https://platform.openai.com/docs/guides/reasoning#limitations
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for reasoning models",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopP",
				Details: "TopP is not supported for reasoning models",
			})
		}
		if call.FrequencyPenalty != nil {
			params.FrequencyPenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "FrequencyPenalty",
				Details: "FrequencyPenalty is not supported for reasoning models",
			})
		}
		if call.PresencePenalty != nil {
			params.PresencePenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "PresencePenalty",
				Details: "PresencePenalty is not supported for reasoning models",
			})
		}

		// reasoning models use max_completion_tokens instead of max_tokens
		// (keep a value already set by provider options, if any)
		if call.MaxOutputTokens != nil {
			if !params.MaxCompletionTokens.Valid() {
				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
			}
			params.MaxTokens = param.Opt[int64]{}
		}
	}

	// Handle search preview models: they reject temperature.
	if isSearchPreviewModel(o.modelID) {
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for the search preview models and has been removed.",
			})
		}
	}

	// Apply provider-specific options, preferring a caller-supplied hook
	// over the built-in implementation.
	prepareOptions := prepareCallWithOptions
	if o.options.hooks.PrepareCallWithOptions != nil {
		prepareOptions = o.options.hooks.PrepareCallWithOptions
	}

	optionsWarnings, err := prepareOptions(o, params, call)
	if err != nil {
		return nil, nil, err
	}

	if len(optionsWarnings) > 0 {
		warnings = append(warnings, optionsWarnings...)
	}

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}
	return params, warnings, nil
}
 409
// handleError converts OpenAI SDK errors into ai.APICallError values,
// capturing request/response dumps and lower-cased response headers.
// Non-OpenAI errors are returned unchanged.
//
// NOTE(review): this assumes apiErr.Response and apiErr.Request are non-nil
// whenever errors.As succeeds — confirm the SDK guarantees that.
func (o languageModel) handleError(err error) error {
	var apiErr *openai.Error
	if errors.As(err, &apiErr) {
		requestDump := apiErr.DumpRequest(true)
		responseDump := apiErr.DumpResponse(true)
		// Flatten multi-valued headers, keeping only the last value.
		headers := map[string]string{}
		for k, h := range apiErr.Response.Header {
			v := h[len(h)-1]
			headers[strings.ToLower(k)] = v
		}
		return ai.NewAPICallError(
			apiErr.Message,
			apiErr.Request.URL.String(),
			string(requestDump),
			apiErr.StatusCode,
			headers,
			string(responseDump),
			apiErr,
			false, // not marked retryable
		)
	}
	return err
}
 433
 434// Generate implements ai.LanguageModel.
 435func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
 436	params, warnings, err := o.prepareParams(call)
 437	if err != nil {
 438		return nil, err
 439	}
 440	response, err := o.client.Chat.Completions.New(ctx, *params)
 441	if err != nil {
 442		return nil, o.handleError(err)
 443	}
 444
 445	if len(response.Choices) == 0 {
 446		return nil, errors.New("no response generated")
 447	}
 448	choice := response.Choices[0]
 449	content := make([]ai.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
 450	text := choice.Message.Content
 451	if text != "" {
 452		content = append(content, ai.TextContent{
 453			Text: text,
 454		})
 455	}
 456
 457	for _, tc := range choice.Message.ToolCalls {
 458		toolCallID := tc.ID
 459		if toolCallID == "" {
 460			toolCallID = uuid.NewString()
 461		}
 462		content = append(content, ai.ToolCallContent{
 463			ProviderExecuted: false, // TODO: update when handling other tools
 464			ToolCallID:       toolCallID,
 465			ToolName:         tc.Function.Name,
 466			Input:            tc.Function.Arguments,
 467		})
 468	}
 469	// Handle annotations/citations
 470	for _, annotation := range choice.Message.Annotations {
 471		if annotation.Type == "url_citation" {
 472			content = append(content, ai.SourceContent{
 473				SourceType: ai.SourceTypeURL,
 474				ID:         uuid.NewString(),
 475				URL:        annotation.URLCitation.URL,
 476				Title:      annotation.URLCitation.Title,
 477			})
 478		}
 479	}
 480
 481	completionTokenDetails := response.Usage.CompletionTokensDetails
 482	promptTokenDetails := response.Usage.PromptTokensDetails
 483
 484	// Build provider metadata
 485	providerMetadata := &ProviderMetadata{}
 486	// Add logprobs if available
 487	if len(choice.Logprobs.Content) > 0 {
 488		providerMetadata.Logprobs = choice.Logprobs.Content
 489	}
 490
 491	// Add prediction tokens if available
 492	if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
 493		if completionTokenDetails.AcceptedPredictionTokens > 0 {
 494			providerMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
 495		}
 496		if completionTokenDetails.RejectedPredictionTokens > 0 {
 497			providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
 498		}
 499	}
 500
 501	return &ai.Response{
 502		Content: content,
 503		Usage: ai.Usage{
 504			InputTokens:     response.Usage.PromptTokens,
 505			OutputTokens:    response.Usage.CompletionTokens,
 506			TotalTokens:     response.Usage.TotalTokens,
 507			ReasoningTokens: completionTokenDetails.ReasoningTokens,
 508			CacheReadTokens: promptTokenDetails.CachedTokens,
 509		},
 510		FinishReason: mapOpenAiFinishReason(choice.FinishReason),
 511		ProviderMetadata: ai.ProviderMetadata{
 512			Name: providerMetadata,
 513		},
 514		Warnings: warnings,
 515	}, nil
 516}
 517
// toolCall tracks the incremental state of one streamed tool invocation:
// the accumulated JSON arguments and whether the final tool-call part has
// already been emitted for it.
type toolCall struct {
	id          string
	name        string
	arguments   string
	hasFinished bool
}
 524
 525// Stream implements ai.LanguageModel.
 526func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
 527	params, warnings, err := o.prepareParams(call)
 528	if err != nil {
 529		return nil, err
 530	}
 531
 532	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
 533		IncludeUsage: openai.Bool(true),
 534	}
 535
 536	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
 537	isActiveText := false
 538	toolCalls := make(map[int64]toolCall)
 539
 540	// Build provider metadata for streaming
 541	streamProviderMetadata := &ProviderMetadata{}
 542	acc := openai.ChatCompletionAccumulator{}
 543	var usage ai.Usage
 544	return func(yield func(ai.StreamPart) bool) {
 545		if len(warnings) > 0 {
 546			if !yield(ai.StreamPart{
 547				Type:     ai.StreamPartTypeWarnings,
 548				Warnings: warnings,
 549			}) {
 550				return
 551			}
 552		}
 553		for stream.Next() {
 554			chunk := stream.Current()
 555			acc.AddChunk(chunk)
 556			if chunk.Usage.TotalTokens > 0 {
 557				// we do this here because the acc does not add prompt details
 558				completionTokenDetails := chunk.Usage.CompletionTokensDetails
 559				promptTokenDetails := chunk.Usage.PromptTokensDetails
 560				usage = ai.Usage{
 561					InputTokens:     chunk.Usage.PromptTokens,
 562					OutputTokens:    chunk.Usage.CompletionTokens,
 563					TotalTokens:     chunk.Usage.TotalTokens,
 564					ReasoningTokens: completionTokenDetails.ReasoningTokens,
 565					CacheReadTokens: promptTokenDetails.CachedTokens,
 566				}
 567
 568				// Add prediction tokens if available
 569				if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
 570					if completionTokenDetails.AcceptedPredictionTokens > 0 {
 571						streamProviderMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
 572					}
 573					if completionTokenDetails.RejectedPredictionTokens > 0 {
 574						streamProviderMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
 575					}
 576				}
 577			}
 578			if len(chunk.Choices) == 0 {
 579				continue
 580			}
 581			for _, choice := range chunk.Choices {
 582				switch {
 583				case choice.Delta.Content != "":
 584					if !isActiveText {
 585						isActiveText = true
 586						if !yield(ai.StreamPart{
 587							Type: ai.StreamPartTypeTextStart,
 588							ID:   "0",
 589						}) {
 590							return
 591						}
 592					}
 593					if !yield(ai.StreamPart{
 594						Type:  ai.StreamPartTypeTextDelta,
 595						ID:    "0",
 596						Delta: choice.Delta.Content,
 597					}) {
 598						return
 599					}
 600				case len(choice.Delta.ToolCalls) > 0:
 601					if isActiveText {
 602						isActiveText = false
 603						if !yield(ai.StreamPart{
 604							Type: ai.StreamPartTypeTextEnd,
 605							ID:   "0",
 606						}) {
 607							return
 608						}
 609					}
 610
 611					for _, toolCallDelta := range choice.Delta.ToolCalls {
 612						if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
 613							if existingToolCall.hasFinished {
 614								continue
 615							}
 616							if toolCallDelta.Function.Arguments != "" {
 617								existingToolCall.arguments += toolCallDelta.Function.Arguments
 618							}
 619							if !yield(ai.StreamPart{
 620								Type:  ai.StreamPartTypeToolInputDelta,
 621								ID:    existingToolCall.id,
 622								Delta: toolCallDelta.Function.Arguments,
 623							}) {
 624								return
 625							}
 626							toolCalls[toolCallDelta.Index] = existingToolCall
 627							if xjson.IsValid(existingToolCall.arguments) {
 628								if !yield(ai.StreamPart{
 629									Type: ai.StreamPartTypeToolInputEnd,
 630									ID:   existingToolCall.id,
 631								}) {
 632									return
 633								}
 634
 635								if !yield(ai.StreamPart{
 636									Type:          ai.StreamPartTypeToolCall,
 637									ID:            existingToolCall.id,
 638									ToolCallName:  existingToolCall.name,
 639									ToolCallInput: existingToolCall.arguments,
 640								}) {
 641									return
 642								}
 643								existingToolCall.hasFinished = true
 644								toolCalls[toolCallDelta.Index] = existingToolCall
 645							}
 646						} else {
 647							// Does not exist
 648							var err error
 649							if toolCallDelta.Type != "function" {
 650								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
 651							}
 652							if toolCallDelta.ID == "" {
 653								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
 654							}
 655							if toolCallDelta.Function.Name == "" {
 656								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
 657							}
 658							if err != nil {
 659								yield(ai.StreamPart{
 660									Type:  ai.StreamPartTypeError,
 661									Error: o.handleError(stream.Err()),
 662								})
 663								return
 664							}
 665
 666							if !yield(ai.StreamPart{
 667								Type:         ai.StreamPartTypeToolInputStart,
 668								ID:           toolCallDelta.ID,
 669								ToolCallName: toolCallDelta.Function.Name,
 670							}) {
 671								return
 672							}
 673							toolCalls[toolCallDelta.Index] = toolCall{
 674								id:        toolCallDelta.ID,
 675								name:      toolCallDelta.Function.Name,
 676								arguments: toolCallDelta.Function.Arguments,
 677							}
 678
 679							exTc := toolCalls[toolCallDelta.Index]
 680							if exTc.arguments != "" {
 681								if !yield(ai.StreamPart{
 682									Type:  ai.StreamPartTypeToolInputDelta,
 683									ID:    exTc.id,
 684									Delta: exTc.arguments,
 685								}) {
 686									return
 687								}
 688								if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
 689									if !yield(ai.StreamPart{
 690										Type: ai.StreamPartTypeToolInputEnd,
 691										ID:   toolCallDelta.ID,
 692									}) {
 693										return
 694									}
 695
 696									if !yield(ai.StreamPart{
 697										Type:          ai.StreamPartTypeToolCall,
 698										ID:            exTc.id,
 699										ToolCallName:  exTc.name,
 700										ToolCallInput: exTc.arguments,
 701									}) {
 702										return
 703									}
 704									exTc.hasFinished = true
 705									toolCalls[toolCallDelta.Index] = exTc
 706								}
 707							}
 708							continue
 709						}
 710					}
 711				}
 712			}
 713
 714			// Check for annotations in the delta's raw JSON
 715			for _, choice := range chunk.Choices {
 716				if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
 717					for _, annotation := range annotations {
 718						if annotation.Type == "url_citation" {
 719							if !yield(ai.StreamPart{
 720								Type:       ai.StreamPartTypeSource,
 721								ID:         uuid.NewString(),
 722								SourceType: ai.SourceTypeURL,
 723								URL:        annotation.URLCitation.URL,
 724								Title:      annotation.URLCitation.Title,
 725							}) {
 726								return
 727							}
 728						}
 729					}
 730				}
 731			}
 732		}
 733		err := stream.Err()
 734		if err == nil || errors.Is(err, io.EOF) {
 735			// finished
 736			if isActiveText {
 737				isActiveText = false
 738				if !yield(ai.StreamPart{
 739					Type: ai.StreamPartTypeTextEnd,
 740					ID:   "0",
 741				}) {
 742					return
 743				}
 744			}
 745
 746			// Add logprobs if available
 747			if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
 748				streamProviderMetadata.Logprobs = acc.Choices[0].Logprobs.Content
 749			}
 750
 751			// Handle annotations/citations from accumulated response
 752			if len(acc.Choices) > 0 {
 753				for _, annotation := range acc.Choices[0].Message.Annotations {
 754					if annotation.Type == "url_citation" {
 755						if !yield(ai.StreamPart{
 756							Type:       ai.StreamPartTypeSource,
 757							ID:         acc.ID,
 758							SourceType: ai.SourceTypeURL,
 759							URL:        annotation.URLCitation.URL,
 760							Title:      annotation.URLCitation.Title,
 761						}) {
 762							return
 763						}
 764					}
 765				}
 766			}
 767
 768			finishReason := mapOpenAiFinishReason(acc.Choices[0].FinishReason)
 769			yield(ai.StreamPart{
 770				Type:         ai.StreamPartTypeFinish,
 771				Usage:        usage,
 772				FinishReason: finishReason,
 773				ProviderMetadata: ai.ProviderMetadata{
 774					Name: streamProviderMetadata,
 775				},
 776			})
 777			return
 778		} else {
 779			yield(ai.StreamPart{
 780				Type:  ai.StreamPartTypeError,
 781				Error: o.handleError(err),
 782			})
 783			return
 784		}
 785	}, nil
 786}
 787
// ParseOptions implements ai.Provider by decoding a generic options map
// into this provider's typed *ProviderOptions.
func (o *provider) ParseOptions(data map[string]any) (ai.ProviderOptionsData, error) {
	var options ProviderOptions
	if err := ai.ParseOptions(data, &options); err != nil {
		return nil, err
	}
	return &options, nil
}

// Name implements ai.Provider.
//
// NOTE(review): this returns the package constant Name, while
// languageModel.Provider returns the configurable options.name — confirm
// whether WithName should be reflected here as well.
func (o *provider) Name() string {
	return Name
}
 799
 800func mapOpenAiFinishReason(finishReason string) ai.FinishReason {
 801	switch finishReason {
 802	case "stop":
 803		return ai.FinishReasonStop
 804	case "length":
 805		return ai.FinishReasonLength
 806	case "content_filter":
 807		return ai.FinishReasonContentFilter
 808	case "function_call", "tool_calls":
 809		return ai.FinishReasonToolCalls
 810	default:
 811		return ai.FinishReasonUnknown
 812	}
 813}
 814
 815func isReasoningModel(modelID string) bool {
 816	return strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5") || strings.HasPrefix(modelID, "gpt-5-chat")
 817}
 818
 819func isSearchPreviewModel(modelID string) bool {
 820	return strings.Contains(modelID, "search-preview")
 821}
 822
 823func supportsFlexProcessing(modelID string) bool {
 824	return strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini") || strings.HasPrefix(modelID, "gpt-5")
 825}
 826
 827func supportsPriorityProcessing(modelID string) bool {
 828	return strings.HasPrefix(modelID, "gpt-4") || strings.HasPrefix(modelID, "gpt-5") ||
 829		strings.HasPrefix(modelID, "gpt-5-mini") || strings.HasPrefix(modelID, "o3") ||
 830		strings.HasPrefix(modelID, "o4-mini")
 831}
 832
// toOpenAiTools converts provider-agnostic tool definitions and an optional
// tool choice into OpenAI chat-completion params. Unsupported tool types
// are skipped with a warning. The returned tool choice is nil when the
// caller did not specify one.
func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				// Type assertion failed despite the declared type; skip
				// silently rather than panic.
				continue
			}
			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
				OfFunction: &openai.ChatCompletionFunctionToolParam{
					Function: shared.FunctionDefinitionParam{
						Name:        ft.Name,
						Description: param.NewOpt(ft.Description),
						Parameters:  openai.FunctionParameters(ft.InputSchema),
						Strict:      param.NewOpt(false),
					},
					Type: "function",
				},
			})
			continue
		}

		// TODO: handle provider tool calls
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return openAiTools, openAiToolChoice, warnings
	}

	// Any value other than auto/none is treated as the name of a specific
	// function the model must call.
	switch *toolChoice {
	case ai.ToolChoiceAuto:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("auto"),
		}
	case ai.ToolChoiceNone:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("none"),
		}
	default:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
				Type: "function",
				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
					Name: string(*toolChoice),
				},
			},
		}
	}
	return openAiTools, openAiToolChoice, warnings
}
 886
 887func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
 888	var messages []openai.ChatCompletionMessageParamUnion
 889	var warnings []ai.CallWarning
 890	for _, msg := range prompt {
 891		switch msg.Role {
 892		case ai.MessageRoleSystem:
 893			var systemPromptParts []string
 894			for _, c := range msg.Content {
 895				if c.GetType() != ai.ContentTypeText {
 896					warnings = append(warnings, ai.CallWarning{
 897						Type:    ai.CallWarningTypeOther,
 898						Message: "system prompt can only have text content",
 899					})
 900					continue
 901				}
 902				textPart, ok := ai.AsContentType[ai.TextPart](c)
 903				if !ok {
 904					warnings = append(warnings, ai.CallWarning{
 905						Type:    ai.CallWarningTypeOther,
 906						Message: "system prompt text part does not have the right type",
 907					})
 908					continue
 909				}
 910				text := textPart.Text
 911				if strings.TrimSpace(text) != "" {
 912					systemPromptParts = append(systemPromptParts, textPart.Text)
 913				}
 914			}
 915			if len(systemPromptParts) == 0 {
 916				warnings = append(warnings, ai.CallWarning{
 917					Type:    ai.CallWarningTypeOther,
 918					Message: "system prompt has no text parts",
 919				})
 920				continue
 921			}
 922			messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
 923		case ai.MessageRoleUser:
 924			// simple user message just text content
 925			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
 926				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
 927				if !ok {
 928					warnings = append(warnings, ai.CallWarning{
 929						Type:    ai.CallWarningTypeOther,
 930						Message: "user message text part does not have the right type",
 931					})
 932					continue
 933				}
 934				messages = append(messages, openai.UserMessage(textPart.Text))
 935				continue
 936			}
 937			// text content and attachments
 938			// for now we only support image content later we need to check
 939			// TODO: add the supported media types to the language model so we
 940			//  can use that to validate the data here.
 941			var content []openai.ChatCompletionContentPartUnionParam
 942			for _, c := range msg.Content {
 943				switch c.GetType() {
 944				case ai.ContentTypeText:
 945					textPart, ok := ai.AsContentType[ai.TextPart](c)
 946					if !ok {
 947						warnings = append(warnings, ai.CallWarning{
 948							Type:    ai.CallWarningTypeOther,
 949							Message: "user message text part does not have the right type",
 950						})
 951						continue
 952					}
 953					content = append(content, openai.ChatCompletionContentPartUnionParam{
 954						OfText: &openai.ChatCompletionContentPartTextParam{
 955							Text: textPart.Text,
 956						},
 957					})
 958				case ai.ContentTypeFile:
 959					filePart, ok := ai.AsContentType[ai.FilePart](c)
 960					if !ok {
 961						warnings = append(warnings, ai.CallWarning{
 962							Type:    ai.CallWarningTypeOther,
 963							Message: "user message file part does not have the right type",
 964						})
 965						continue
 966					}
 967
 968					switch {
 969					case strings.HasPrefix(filePart.MediaType, "image/"):
 970						// Handle image files
 971						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
 972						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
 973						imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}
 974
 975						// Check for provider-specific options like image detail
 976						if providerOptions, ok := filePart.ProviderOptions[Name]; ok {
 977							if detail, ok := providerOptions.(*ProviderFileOptions); ok {
 978								imageURL.Detail = detail.ImageDetail
 979							}
 980						}
 981
 982						imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
 983						content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
 984
 985					case filePart.MediaType == "audio/wav":
 986						// Handle WAV audio files
 987						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
 988						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
 989							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
 990								Data:   base64Encoded,
 991								Format: "wav",
 992							},
 993						}
 994						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
 995
 996					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
 997						// Handle MP3 audio files
 998						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
 999						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
1000							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
1001								Data:   base64Encoded,
1002								Format: "mp3",
1003							},
1004						}
1005						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
1006
1007					case filePart.MediaType == "application/pdf":
1008						// Handle PDF files
1009						dataStr := string(filePart.Data)
1010
1011						// Check if data looks like a file ID (starts with "file-")
1012						if strings.HasPrefix(dataStr, "file-") {
1013							fileBlock := openai.ChatCompletionContentPartFileParam{
1014								File: openai.ChatCompletionContentPartFileFileParam{
1015									FileID: param.NewOpt(dataStr),
1016								},
1017							}
1018							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
1019						} else {
1020							// Handle as base64 data
1021							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
1022							data := "data:application/pdf;base64," + base64Encoded
1023
1024							filename := filePart.Filename
1025							if filename == "" {
1026								// Generate default filename based on content index
1027								filename = fmt.Sprintf("part-%d.pdf", len(content))
1028							}
1029
1030							fileBlock := openai.ChatCompletionContentPartFileParam{
1031								File: openai.ChatCompletionContentPartFileFileParam{
1032									Filename: param.NewOpt(filename),
1033									FileData: param.NewOpt(data),
1034								},
1035							}
1036							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
1037						}
1038
1039					default:
1040						warnings = append(warnings, ai.CallWarning{
1041							Type:    ai.CallWarningTypeOther,
1042							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
1043						})
1044					}
1045				}
1046			}
1047			messages = append(messages, openai.UserMessage(content))
1048		case ai.MessageRoleAssistant:
1049			// simple assistant message just text content
1050			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
1051				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
1052				if !ok {
1053					warnings = append(warnings, ai.CallWarning{
1054						Type:    ai.CallWarningTypeOther,
1055						Message: "assistant message text part does not have the right type",
1056					})
1057					continue
1058				}
1059				messages = append(messages, openai.AssistantMessage(textPart.Text))
1060				continue
1061			}
1062			assistantMsg := openai.ChatCompletionAssistantMessageParam{
1063				Role: "assistant",
1064			}
1065			for _, c := range msg.Content {
1066				switch c.GetType() {
1067				case ai.ContentTypeText:
1068					textPart, ok := ai.AsContentType[ai.TextPart](c)
1069					if !ok {
1070						warnings = append(warnings, ai.CallWarning{
1071							Type:    ai.CallWarningTypeOther,
1072							Message: "assistant message text part does not have the right type",
1073						})
1074						continue
1075					}
1076					assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
1077						OfString: param.NewOpt(textPart.Text),
1078					}
1079				case ai.ContentTypeToolCall:
1080					toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
1081					if !ok {
1082						warnings = append(warnings, ai.CallWarning{
1083							Type:    ai.CallWarningTypeOther,
1084							Message: "assistant message tool part does not have the right type",
1085						})
1086						continue
1087					}
1088					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
1089						openai.ChatCompletionMessageToolCallUnionParam{
1090							OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
1091								ID:   toolCallPart.ToolCallID,
1092								Type: "function",
1093								Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
1094									Name:      toolCallPart.ToolName,
1095									Arguments: toolCallPart.Input,
1096								},
1097							},
1098						})
1099				}
1100			}
1101			messages = append(messages, openai.ChatCompletionMessageParamUnion{
1102				OfAssistant: &assistantMsg,
1103			})
1104		case ai.MessageRoleTool:
1105			for _, c := range msg.Content {
1106				if c.GetType() != ai.ContentTypeToolResult {
1107					warnings = append(warnings, ai.CallWarning{
1108						Type:    ai.CallWarningTypeOther,
1109						Message: "tool message can only have tool result content",
1110					})
1111					continue
1112				}
1113
1114				toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
1115				if !ok {
1116					warnings = append(warnings, ai.CallWarning{
1117						Type:    ai.CallWarningTypeOther,
1118						Message: "tool message result part does not have the right type",
1119					})
1120					continue
1121				}
1122
1123				switch toolResultPart.Output.GetType() {
1124				case ai.ToolResultContentTypeText:
1125					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
1126					if !ok {
1127						warnings = append(warnings, ai.CallWarning{
1128							Type:    ai.CallWarningTypeOther,
1129							Message: "tool result output does not have the right type",
1130						})
1131						continue
1132					}
1133					messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
1134				case ai.ToolResultContentTypeError:
1135					// TODO: check if better handling is needed
1136					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
1137					if !ok {
1138						warnings = append(warnings, ai.CallWarning{
1139							Type:    ai.CallWarningTypeOther,
1140							Message: "tool result output does not have the right type",
1141						})
1142						continue
1143					}
1144					messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
1145				}
1146			}
1147		}
1148	}
1149	return messages, warnings
1150}
1151
1152// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
1153func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
1154	var annotations []openai.ChatCompletionMessageAnnotation
1155
1156	// Parse the raw JSON to extract annotations
1157	var deltaData map[string]any
1158	if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
1159		return annotations
1160	}
1161
1162	// Check if annotations exist in the delta
1163	if annotationsData, ok := deltaData["annotations"].([]any); ok {
1164		for _, annotationData := range annotationsData {
1165			if annotationMap, ok := annotationData.(map[string]any); ok {
1166				if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
1167					if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
1168						annotation := openai.ChatCompletionMessageAnnotation{
1169							Type: "url_citation",
1170							URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
1171								URL:   urlCitationData["url"].(string),
1172								Title: urlCitationData["title"].(string),
1173							},
1174						}
1175						annotations = append(annotations, annotation)
1176					}
1177				}
1178			}
1179		}
1180	}
1181
1182	return annotations
1183}