openai.go

   1package openai
   2
   3import (
   4	"cmp"
   5	"context"
   6	"encoding/base64"
   7	"encoding/json"
   8	"errors"
   9	"fmt"
  10	"io"
  11	"maps"
  12	"strings"
  13
  14	"github.com/charmbracelet/fantasy/ai"
  15	xjson "github.com/charmbracelet/x/json"
  16	"github.com/google/uuid"
  17	"github.com/openai/openai-go/v2"
  18	"github.com/openai/openai-go/v2/option"
  19	"github.com/openai/openai-go/v2/packages/param"
  20	"github.com/openai/openai-go/v2/shared"
  21)
  22
const (
	// Name is the default provider name. It is also the key under which
	// OpenAI-specific entries are looked up in call provider options and
	// stored in response provider metadata.
	Name       = "openai"
	// DefaultURL is the API base URL used when no base URL is configured.
	DefaultURL = "https://api.openai.com/v1"
)
  27
// provider is the OpenAI implementation of ai.Provider returned by New.
type provider struct {
	// options is the resolved configuration shared by every model created
	// from this provider.
	options options
}
  31
// PrepareCallWithOptions is the signature of the hook that applies
// provider-specific call options to the chat completion params. It returns
// the warnings produced while doing so; a non-nil error aborts the call.
type PrepareCallWithOptions = func(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error)
  33
// Hooks holds optional callbacks that customize how requests are built.
type Hooks struct {
	// PrepareCallWithOptions, when non-nil, is used instead of the default
	// provider-options handling while preparing request params.
	PrepareCallWithOptions PrepareCallWithOptions
}
  37
// options is the provider configuration assembled by New from Option values.
type options struct {
	// baseURL is the API base URL; New falls back to DefaultURL when empty.
	baseURL      string
	// apiKey is passed to the OpenAI client when non-empty.
	apiKey       string
	// organization, when set, is sent as the OpenAi-Organization header.
	organization string
	// project, when set, is sent as the OpenAi-Project header.
	project      string
	// name is the provider name reported by models; New falls back to Name.
	name         string
	// hooks holds the optional request-building callbacks.
	hooks        Hooks
	// headers are extra HTTP headers attached to every request.
	headers      map[string]string
	// client is an optional custom HTTP client for the OpenAI client.
	client       option.HTTPClient
}
  48
// Option mutates the provider configuration; pass Options to New.
type Option = func(*options)
  50
  51func New(opts ...Option) ai.Provider {
  52	providerOptions := options{
  53		headers: map[string]string{},
  54	}
  55	for _, o := range opts {
  56		o(&providerOptions)
  57	}
  58
  59	providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
  60	providerOptions.name = cmp.Or(providerOptions.name, Name)
  61
  62	if providerOptions.organization != "" {
  63		providerOptions.headers["OpenAi-Organization"] = providerOptions.organization
  64	}
  65	if providerOptions.project != "" {
  66		providerOptions.headers["OpenAi-Project"] = providerOptions.project
  67	}
  68
  69	return &provider{options: providerOptions}
  70}
  71
  72func WithBaseURL(baseURL string) Option {
  73	return func(o *options) {
  74		o.baseURL = baseURL
  75	}
  76}
  77
  78func WithAPIKey(apiKey string) Option {
  79	return func(o *options) {
  80		o.apiKey = apiKey
  81	}
  82}
  83
  84func WithOrganization(organization string) Option {
  85	return func(o *options) {
  86		o.organization = organization
  87	}
  88}
  89
  90func WithProject(project string) Option {
  91	return func(o *options) {
  92		o.project = project
  93	}
  94}
  95
  96func WithName(name string) Option {
  97	return func(o *options) {
  98		o.name = name
  99	}
 100}
 101
 102func WithHeaders(headers map[string]string) Option {
 103	return func(o *options) {
 104		maps.Copy(o.headers, headers)
 105	}
 106}
 107
 108func WithHTTPClient(client option.HTTPClient) Option {
 109	return func(o *options) {
 110		o.client = client
 111	}
 112}
 113
 114func WithHooks(hooks Hooks) Option {
 115	return func(o *options) {
 116		o.hooks = hooks
 117	}
 118}
 119
 120// LanguageModel implements ai.Provider.
 121func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
 122	openaiClientOptions := []option.RequestOption{}
 123	if o.options.apiKey != "" {
 124		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
 125	}
 126	if o.options.baseURL != "" {
 127		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
 128	}
 129
 130	for key, value := range o.options.headers {
 131		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
 132	}
 133
 134	if o.options.client != nil {
 135		openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
 136	}
 137
 138	return languageModel{
 139		modelID:  modelID,
 140		provider: o.options.name,
 141		options:  o.options,
 142		client:   openai.NewClient(openaiClientOptions...),
 143	}, nil
 144}
 145
// languageModel is the ai.LanguageModel backed by the OpenAI chat
// completions client.
type languageModel struct {
	// provider is the provider name reported by Provider.
	provider string
	// modelID is the model identifier sent with every request.
	modelID  string
	// client is the OpenAI client used to issue requests.
	client   openai.Client
	// options is the provider configuration, consulted for hooks.
	options  options
}
 152
// Model implements ai.LanguageModel.
// It returns the model ID this language model was created with.
func (o languageModel) Model() string {
	return o.modelID
}
 157
// Provider implements ai.LanguageModel.
// It returns the provider name captured when the model was created.
func (o languageModel) Provider() string {
	return o.provider
}
 162
 163func prepareCallWithOptions(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
 164	if call.ProviderOptions == nil {
 165		return nil, nil
 166	}
 167	var warnings []ai.CallWarning
 168	providerOptions := &ProviderOptions{}
 169	if v, ok := call.ProviderOptions[Name]; ok {
 170		providerOptions, ok = v.(*ProviderOptions)
 171		if !ok {
 172			return nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
 173		}
 174	}
 175
 176	if providerOptions.LogitBias != nil {
 177		params.LogitBias = providerOptions.LogitBias
 178	}
 179	if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
 180		providerOptions.LogProbs = nil
 181	}
 182	if providerOptions.LogProbs != nil {
 183		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
 184	}
 185	if providerOptions.TopLogProbs != nil {
 186		params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
 187	}
 188	if providerOptions.User != nil {
 189		params.User = param.NewOpt(*providerOptions.User)
 190	}
 191	if providerOptions.ParallelToolCalls != nil {
 192		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
 193	}
 194	if providerOptions.MaxCompletionTokens != nil {
 195		params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
 196	}
 197
 198	if providerOptions.TextVerbosity != nil {
 199		params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
 200	}
 201	if providerOptions.Prediction != nil {
 202		// Convert map[string]any to ChatCompletionPredictionContentParam
 203		if content, ok := providerOptions.Prediction["content"]; ok {
 204			if contentStr, ok := content.(string); ok {
 205				params.Prediction = openai.ChatCompletionPredictionContentParam{
 206					Content: openai.ChatCompletionPredictionContentContentUnionParam{
 207						OfString: param.NewOpt(contentStr),
 208					},
 209				}
 210			}
 211		}
 212	}
 213	if providerOptions.Store != nil {
 214		params.Store = param.NewOpt(*providerOptions.Store)
 215	}
 216	if providerOptions.Metadata != nil {
 217		// Convert map[string]any to map[string]string
 218		metadata := make(map[string]string)
 219		for k, v := range providerOptions.Metadata {
 220			if str, ok := v.(string); ok {
 221				metadata[k] = str
 222			}
 223		}
 224		params.Metadata = metadata
 225	}
 226	if providerOptions.PromptCacheKey != nil {
 227		params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
 228	}
 229	if providerOptions.SafetyIdentifier != nil {
 230		params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
 231	}
 232	if providerOptions.ServiceTier != nil {
 233		params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
 234	}
 235
 236	if providerOptions.ReasoningEffort != nil {
 237		switch *providerOptions.ReasoningEffort {
 238		case ReasoningEffortMinimal:
 239			params.ReasoningEffort = shared.ReasoningEffortMinimal
 240		case ReasoningEffortLow:
 241			params.ReasoningEffort = shared.ReasoningEffortLow
 242		case ReasoningEffortMedium:
 243			params.ReasoningEffort = shared.ReasoningEffortMedium
 244		case ReasoningEffortHigh:
 245			params.ReasoningEffort = shared.ReasoningEffortHigh
 246		default:
 247			return nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
 248		}
 249	}
 250
 251	if isReasoningModel(model.Model()) {
 252		if providerOptions.LogitBias != nil {
 253			params.LogitBias = nil
 254			warnings = append(warnings, ai.CallWarning{
 255				Type:    ai.CallWarningTypeUnsupportedSetting,
 256				Setting: "LogitBias",
 257				Message: "LogitBias is not supported for reasoning models",
 258			})
 259		}
 260		if providerOptions.LogProbs != nil {
 261			params.Logprobs = param.Opt[bool]{}
 262			warnings = append(warnings, ai.CallWarning{
 263				Type:    ai.CallWarningTypeUnsupportedSetting,
 264				Setting: "Logprobs",
 265				Message: "Logprobs is not supported for reasoning models",
 266			})
 267		}
 268		if providerOptions.TopLogProbs != nil {
 269			params.TopLogprobs = param.Opt[int64]{}
 270			warnings = append(warnings, ai.CallWarning{
 271				Type:    ai.CallWarningTypeUnsupportedSetting,
 272				Setting: "TopLogprobs",
 273				Message: "TopLogprobs is not supported for reasoning models",
 274			})
 275		}
 276
 277		// reasoning models use max_completion_tokens instead of max_tokens
 278		if call.MaxOutputTokens != nil {
 279			if providerOptions.MaxCompletionTokens == nil {
 280				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
 281			}
 282			params.MaxTokens = param.Opt[int64]{}
 283		}
 284
 285	}
 286
 287	// Handle service tier validation
 288	if providerOptions.ServiceTier != nil {
 289		serviceTier := *providerOptions.ServiceTier
 290		if serviceTier == "flex" && !supportsFlexProcessing(model.Model()) {
 291			params.ServiceTier = ""
 292			warnings = append(warnings, ai.CallWarning{
 293				Type:    ai.CallWarningTypeUnsupportedSetting,
 294				Setting: "ServiceTier",
 295				Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
 296			})
 297		} else if serviceTier == "priority" && !supportsPriorityProcessing(model.Model()) {
 298			params.ServiceTier = ""
 299			warnings = append(warnings, ai.CallWarning{
 300				Type:    ai.CallWarningTypeUnsupportedSetting,
 301				Setting: "ServiceTier",
 302				Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported",
 303			})
 304		}
 305	}
 306	return warnings, nil
 307}
 308
 309func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
 310	params := &openai.ChatCompletionNewParams{}
 311	messages, warnings := toPrompt(call.Prompt)
 312	if call.TopK != nil {
 313		warnings = append(warnings, ai.CallWarning{
 314			Type:    ai.CallWarningTypeUnsupportedSetting,
 315			Setting: "top_k",
 316		})
 317	}
 318	params.Messages = messages
 319	params.Model = o.modelID
 320
 321	if call.MaxOutputTokens != nil {
 322		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
 323	}
 324	if call.Temperature != nil {
 325		params.Temperature = param.NewOpt(*call.Temperature)
 326	}
 327	if call.TopP != nil {
 328		params.TopP = param.NewOpt(*call.TopP)
 329	}
 330	if call.FrequencyPenalty != nil {
 331		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
 332	}
 333	if call.PresencePenalty != nil {
 334		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
 335	}
 336
 337	if isReasoningModel(o.modelID) {
 338		// remove unsupported settings for reasoning models
 339		// see https://platform.openai.com/docs/guides/reasoning#limitations
 340		if call.Temperature != nil {
 341			params.Temperature = param.Opt[float64]{}
 342			warnings = append(warnings, ai.CallWarning{
 343				Type:    ai.CallWarningTypeUnsupportedSetting,
 344				Setting: "temperature",
 345				Details: "temperature is not supported for reasoning models",
 346			})
 347		}
 348		if call.TopP != nil {
 349			params.TopP = param.Opt[float64]{}
 350			warnings = append(warnings, ai.CallWarning{
 351				Type:    ai.CallWarningTypeUnsupportedSetting,
 352				Setting: "TopP",
 353				Details: "TopP is not supported for reasoning models",
 354			})
 355		}
 356		if call.FrequencyPenalty != nil {
 357			params.FrequencyPenalty = param.Opt[float64]{}
 358			warnings = append(warnings, ai.CallWarning{
 359				Type:    ai.CallWarningTypeUnsupportedSetting,
 360				Setting: "FrequencyPenalty",
 361				Details: "FrequencyPenalty is not supported for reasoning models",
 362			})
 363		}
 364		if call.PresencePenalty != nil {
 365			params.PresencePenalty = param.Opt[float64]{}
 366			warnings = append(warnings, ai.CallWarning{
 367				Type:    ai.CallWarningTypeUnsupportedSetting,
 368				Setting: "PresencePenalty",
 369				Details: "PresencePenalty is not supported for reasoning models",
 370			})
 371		}
 372	}
 373
 374	// Handle search preview models
 375	if isSearchPreviewModel(o.modelID) {
 376		if call.Temperature != nil {
 377			params.Temperature = param.Opt[float64]{}
 378			warnings = append(warnings, ai.CallWarning{
 379				Type:    ai.CallWarningTypeUnsupportedSetting,
 380				Setting: "temperature",
 381				Details: "temperature is not supported for the search preview models and has been removed.",
 382			})
 383		}
 384	}
 385
 386	prepareOptions := prepareCallWithOptions
 387	if o.options.hooks.PrepareCallWithOptions != nil {
 388		prepareOptions = o.options.hooks.PrepareCallWithOptions
 389	}
 390
 391	optionsWarnings, err := prepareOptions(o, params, call)
 392	if err != nil {
 393		return nil, nil, err
 394	}
 395
 396	if len(optionsWarnings) > 0 {
 397		warnings = append(warnings, optionsWarnings...)
 398	}
 399
 400	if len(call.Tools) > 0 {
 401		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
 402		params.Tools = tools
 403		if toolChoice != nil {
 404			params.ToolChoice = *toolChoice
 405		}
 406		warnings = append(warnings, toolWarnings...)
 407	}
 408	return params, warnings, nil
 409}
 410
 411func (o languageModel) handleError(err error) error {
 412	var apiErr *openai.Error
 413	if errors.As(err, &apiErr) {
 414		requestDump := apiErr.DumpRequest(true)
 415		responseDump := apiErr.DumpResponse(true)
 416		headers := map[string]string{}
 417		for k, h := range apiErr.Response.Header {
 418			v := h[len(h)-1]
 419			headers[strings.ToLower(k)] = v
 420		}
 421		return ai.NewAPICallError(
 422			apiErr.Message,
 423			apiErr.Request.URL.String(),
 424			string(requestDump),
 425			apiErr.StatusCode,
 426			headers,
 427			string(responseDump),
 428			apiErr,
 429			false,
 430		)
 431	}
 432	return err
 433}
 434
 435// Generate implements ai.LanguageModel.
 436func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
 437	params, warnings, err := o.prepareParams(call)
 438	if err != nil {
 439		return nil, err
 440	}
 441	response, err := o.client.Chat.Completions.New(ctx, *params)
 442	if err != nil {
 443		return nil, o.handleError(err)
 444	}
 445
 446	if len(response.Choices) == 0 {
 447		return nil, errors.New("no response generated")
 448	}
 449	choice := response.Choices[0]
 450	content := make([]ai.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
 451	text := choice.Message.Content
 452	if text != "" {
 453		content = append(content, ai.TextContent{
 454			Text: text,
 455		})
 456	}
 457
 458	for _, tc := range choice.Message.ToolCalls {
 459		toolCallID := tc.ID
 460		if toolCallID == "" {
 461			toolCallID = uuid.NewString()
 462		}
 463		content = append(content, ai.ToolCallContent{
 464			ProviderExecuted: false, // TODO: update when handling other tools
 465			ToolCallID:       toolCallID,
 466			ToolName:         tc.Function.Name,
 467			Input:            tc.Function.Arguments,
 468		})
 469	}
 470	// Handle annotations/citations
 471	for _, annotation := range choice.Message.Annotations {
 472		if annotation.Type == "url_citation" {
 473			content = append(content, ai.SourceContent{
 474				SourceType: ai.SourceTypeURL,
 475				ID:         uuid.NewString(),
 476				URL:        annotation.URLCitation.URL,
 477				Title:      annotation.URLCitation.Title,
 478			})
 479		}
 480	}
 481
 482	completionTokenDetails := response.Usage.CompletionTokensDetails
 483	promptTokenDetails := response.Usage.PromptTokensDetails
 484
 485	// Build provider metadata
 486	providerMetadata := &ProviderMetadata{}
 487	// Add logprobs if available
 488	if len(choice.Logprobs.Content) > 0 {
 489		providerMetadata.Logprobs = choice.Logprobs.Content
 490	}
 491
 492	// Add prediction tokens if available
 493	if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
 494		if completionTokenDetails.AcceptedPredictionTokens > 0 {
 495			providerMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
 496		}
 497		if completionTokenDetails.RejectedPredictionTokens > 0 {
 498			providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
 499		}
 500	}
 501
 502	return &ai.Response{
 503		Content: content,
 504		Usage: ai.Usage{
 505			InputTokens:     response.Usage.PromptTokens,
 506			OutputTokens:    response.Usage.CompletionTokens,
 507			TotalTokens:     response.Usage.TotalTokens,
 508			ReasoningTokens: completionTokenDetails.ReasoningTokens,
 509			CacheReadTokens: promptTokenDetails.CachedTokens,
 510		},
 511		FinishReason: mapOpenAiFinishReason(choice.FinishReason),
 512		ProviderMetadata: ai.ProviderMetadata{
 513			Name: providerMetadata,
 514		},
 515		Warnings: warnings,
 516	}, nil
 517}
 518
// toolCall tracks one tool call while its pieces arrive over a stream.
type toolCall struct {
	// id is the tool call ID taken from the first delta for this call.
	id          string
	// name is the function name taken from the first delta for this call.
	name        string
	// arguments is the argument text accumulated across deltas.
	arguments   string
	// hasFinished is set once the tool call part has been emitted; later
	// deltas for the same index are then ignored.
	hasFinished bool
}
 525
 526// Stream implements ai.LanguageModel.
 527func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
 528	params, warnings, err := o.prepareParams(call)
 529	if err != nil {
 530		return nil, err
 531	}
 532
 533	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
 534		IncludeUsage: openai.Bool(true),
 535	}
 536
 537	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
 538	isActiveText := false
 539	toolCalls := make(map[int64]toolCall)
 540
 541	// Build provider metadata for streaming
 542	streamProviderMetadata := &ProviderMetadata{}
 543	acc := openai.ChatCompletionAccumulator{}
 544	var usage ai.Usage
 545	return func(yield func(ai.StreamPart) bool) {
 546		if len(warnings) > 0 {
 547			if !yield(ai.StreamPart{
 548				Type:     ai.StreamPartTypeWarnings,
 549				Warnings: warnings,
 550			}) {
 551				return
 552			}
 553		}
 554		for stream.Next() {
 555			chunk := stream.Current()
 556			acc.AddChunk(chunk)
 557			if chunk.Usage.TotalTokens > 0 {
 558				// we do this here because the acc does not add prompt details
 559				completionTokenDetails := chunk.Usage.CompletionTokensDetails
 560				promptTokenDetails := chunk.Usage.PromptTokensDetails
 561				usage = ai.Usage{
 562					InputTokens:     chunk.Usage.PromptTokens,
 563					OutputTokens:    chunk.Usage.CompletionTokens,
 564					TotalTokens:     chunk.Usage.TotalTokens,
 565					ReasoningTokens: completionTokenDetails.ReasoningTokens,
 566					CacheReadTokens: promptTokenDetails.CachedTokens,
 567				}
 568
 569				// Add prediction tokens if available
 570				if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
 571					if completionTokenDetails.AcceptedPredictionTokens > 0 {
 572						streamProviderMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
 573					}
 574					if completionTokenDetails.RejectedPredictionTokens > 0 {
 575						streamProviderMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
 576					}
 577				}
 578			}
 579			if len(chunk.Choices) == 0 {
 580				continue
 581			}
 582			for _, choice := range chunk.Choices {
 583				switch {
 584				case choice.Delta.Content != "":
 585					if !isActiveText {
 586						isActiveText = true
 587						if !yield(ai.StreamPart{
 588							Type: ai.StreamPartTypeTextStart,
 589							ID:   "0",
 590						}) {
 591							return
 592						}
 593					}
 594					if !yield(ai.StreamPart{
 595						Type:  ai.StreamPartTypeTextDelta,
 596						ID:    "0",
 597						Delta: choice.Delta.Content,
 598					}) {
 599						return
 600					}
 601				case len(choice.Delta.ToolCalls) > 0:
 602					if isActiveText {
 603						isActiveText = false
 604						if !yield(ai.StreamPart{
 605							Type: ai.StreamPartTypeTextEnd,
 606							ID:   "0",
 607						}) {
 608							return
 609						}
 610					}
 611
 612					for _, toolCallDelta := range choice.Delta.ToolCalls {
 613						if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
 614							if existingToolCall.hasFinished {
 615								continue
 616							}
 617							if toolCallDelta.Function.Arguments != "" {
 618								existingToolCall.arguments += toolCallDelta.Function.Arguments
 619							}
 620							if !yield(ai.StreamPart{
 621								Type:  ai.StreamPartTypeToolInputDelta,
 622								ID:    existingToolCall.id,
 623								Delta: toolCallDelta.Function.Arguments,
 624							}) {
 625								return
 626							}
 627							toolCalls[toolCallDelta.Index] = existingToolCall
 628							if xjson.IsValid(existingToolCall.arguments) {
 629								if !yield(ai.StreamPart{
 630									Type: ai.StreamPartTypeToolInputEnd,
 631									ID:   existingToolCall.id,
 632								}) {
 633									return
 634								}
 635
 636								if !yield(ai.StreamPart{
 637									Type:          ai.StreamPartTypeToolCall,
 638									ID:            existingToolCall.id,
 639									ToolCallName:  existingToolCall.name,
 640									ToolCallInput: existingToolCall.arguments,
 641								}) {
 642									return
 643								}
 644								existingToolCall.hasFinished = true
 645								toolCalls[toolCallDelta.Index] = existingToolCall
 646							}
 647						} else {
 648							// Does not exist
 649							var err error
 650							if toolCallDelta.Type != "function" {
 651								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
 652							}
 653							if toolCallDelta.ID == "" {
 654								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
 655							}
 656							if toolCallDelta.Function.Name == "" {
 657								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
 658							}
 659							if err != nil {
 660								yield(ai.StreamPart{
 661									Type:  ai.StreamPartTypeError,
 662									Error: o.handleError(stream.Err()),
 663								})
 664								return
 665							}
 666
 667							if !yield(ai.StreamPart{
 668								Type:         ai.StreamPartTypeToolInputStart,
 669								ID:           toolCallDelta.ID,
 670								ToolCallName: toolCallDelta.Function.Name,
 671							}) {
 672								return
 673							}
 674							toolCalls[toolCallDelta.Index] = toolCall{
 675								id:        toolCallDelta.ID,
 676								name:      toolCallDelta.Function.Name,
 677								arguments: toolCallDelta.Function.Arguments,
 678							}
 679
 680							exTc := toolCalls[toolCallDelta.Index]
 681							if exTc.arguments != "" {
 682								if !yield(ai.StreamPart{
 683									Type:  ai.StreamPartTypeToolInputDelta,
 684									ID:    exTc.id,
 685									Delta: exTc.arguments,
 686								}) {
 687									return
 688								}
 689								if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
 690									if !yield(ai.StreamPart{
 691										Type: ai.StreamPartTypeToolInputEnd,
 692										ID:   toolCallDelta.ID,
 693									}) {
 694										return
 695									}
 696
 697									if !yield(ai.StreamPart{
 698										Type:          ai.StreamPartTypeToolCall,
 699										ID:            exTc.id,
 700										ToolCallName:  exTc.name,
 701										ToolCallInput: exTc.arguments,
 702									}) {
 703										return
 704									}
 705									exTc.hasFinished = true
 706									toolCalls[toolCallDelta.Index] = exTc
 707								}
 708							}
 709							continue
 710						}
 711					}
 712				}
 713			}
 714
 715			// Check for annotations in the delta's raw JSON
 716			for _, choice := range chunk.Choices {
 717				if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
 718					for _, annotation := range annotations {
 719						if annotation.Type == "url_citation" {
 720							if !yield(ai.StreamPart{
 721								Type:       ai.StreamPartTypeSource,
 722								ID:         uuid.NewString(),
 723								SourceType: ai.SourceTypeURL,
 724								URL:        annotation.URLCitation.URL,
 725								Title:      annotation.URLCitation.Title,
 726							}) {
 727								return
 728							}
 729						}
 730					}
 731				}
 732			}
 733		}
 734		err := stream.Err()
 735		if err == nil || errors.Is(err, io.EOF) {
 736			// finished
 737			if isActiveText {
 738				isActiveText = false
 739				if !yield(ai.StreamPart{
 740					Type: ai.StreamPartTypeTextEnd,
 741					ID:   "0",
 742				}) {
 743					return
 744				}
 745			}
 746
 747			// Add logprobs if available
 748			if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
 749				streamProviderMetadata.Logprobs = acc.Choices[0].Logprobs.Content
 750			}
 751
 752			// Handle annotations/citations from accumulated response
 753			if len(acc.Choices) > 0 {
 754				for _, annotation := range acc.Choices[0].Message.Annotations {
 755					if annotation.Type == "url_citation" {
 756						if !yield(ai.StreamPart{
 757							Type:       ai.StreamPartTypeSource,
 758							ID:         acc.ID,
 759							SourceType: ai.SourceTypeURL,
 760							URL:        annotation.URLCitation.URL,
 761							Title:      annotation.URLCitation.Title,
 762						}) {
 763							return
 764						}
 765					}
 766				}
 767			}
 768
 769			finishReason := mapOpenAiFinishReason(acc.Choices[0].FinishReason)
 770			yield(ai.StreamPart{
 771				Type:         ai.StreamPartTypeFinish,
 772				Usage:        usage,
 773				FinishReason: finishReason,
 774				ProviderMetadata: ai.ProviderMetadata{
 775					Name: streamProviderMetadata,
 776				},
 777			})
 778			return
 779		} else {
 780			yield(ai.StreamPart{
 781				Type:  ai.StreamPartTypeError,
 782				Error: o.handleError(err),
 783			})
 784			return
 785		}
 786	}, nil
 787}
 788
 789func (o *provider) ParseOptions(data map[string]any) (ai.ProviderOptionsData, error) {
 790	var options ProviderOptions
 791	if err := ai.ParseOptions(data, &options); err != nil {
 792		return nil, err
 793	}
 794	return &options, nil
 795}
 796
// Name returns the package-level Name constant. It does not reflect a
// custom name configured through WithName.
func (o *provider) Name() string {
	return Name
}
 800
 801func mapOpenAiFinishReason(finishReason string) ai.FinishReason {
 802	switch finishReason {
 803	case "stop":
 804		return ai.FinishReasonStop
 805	case "length":
 806		return ai.FinishReasonLength
 807	case "content_filter":
 808		return ai.FinishReasonContentFilter
 809	case "function_call", "tool_calls":
 810		return ai.FinishReasonToolCalls
 811	default:
 812		return ai.FinishReasonUnknown
 813	}
 814}
 815
 816func isReasoningModel(modelID string) bool {
 817	return strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5") || strings.HasPrefix(modelID, "gpt-5-chat")
 818}
 819
 820func isSearchPreviewModel(modelID string) bool {
 821	return strings.Contains(modelID, "search-preview")
 822}
 823
 824func supportsFlexProcessing(modelID string) bool {
 825	return strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini") || strings.HasPrefix(modelID, "gpt-5")
 826}
 827
 828func supportsPriorityProcessing(modelID string) bool {
 829	return strings.HasPrefix(modelID, "gpt-4") || strings.HasPrefix(modelID, "gpt-5") ||
 830		strings.HasPrefix(modelID, "gpt-5-mini") || strings.HasPrefix(modelID, "o3") ||
 831		strings.HasPrefix(modelID, "o4-mini")
 832}
 833
 834func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
 835	for _, tool := range tools {
 836		if tool.GetType() == ai.ToolTypeFunction {
 837			ft, ok := tool.(ai.FunctionTool)
 838			if !ok {
 839				continue
 840			}
 841			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
 842				OfFunction: &openai.ChatCompletionFunctionToolParam{
 843					Function: shared.FunctionDefinitionParam{
 844						Name:        ft.Name,
 845						Description: param.NewOpt(ft.Description),
 846						Parameters:  openai.FunctionParameters(ft.InputSchema),
 847						Strict:      param.NewOpt(false),
 848					},
 849					Type: "function",
 850				},
 851			})
 852			continue
 853		}
 854
 855		// TODO: handle provider tool calls
 856		warnings = append(warnings, ai.CallWarning{
 857			Type:    ai.CallWarningTypeUnsupportedTool,
 858			Tool:    tool,
 859			Message: "tool is not supported",
 860		})
 861	}
 862	if toolChoice == nil {
 863		return openAiTools, openAiToolChoice, warnings
 864	}
 865
 866	switch *toolChoice {
 867	case ai.ToolChoiceAuto:
 868		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
 869			OfAuto: param.NewOpt("auto"),
 870		}
 871	case ai.ToolChoiceNone:
 872		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
 873			OfAuto: param.NewOpt("none"),
 874		}
 875	default:
 876		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
 877			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
 878				Type: "function",
 879				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
 880					Name: string(*toolChoice),
 881				},
 882			},
 883		}
 884	}
 885	return openAiTools, openAiToolChoice, warnings
 886}
 887
 888func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
 889	var messages []openai.ChatCompletionMessageParamUnion
 890	var warnings []ai.CallWarning
 891	for _, msg := range prompt {
 892		switch msg.Role {
 893		case ai.MessageRoleSystem:
 894			var systemPromptParts []string
 895			for _, c := range msg.Content {
 896				if c.GetType() != ai.ContentTypeText {
 897					warnings = append(warnings, ai.CallWarning{
 898						Type:    ai.CallWarningTypeOther,
 899						Message: "system prompt can only have text content",
 900					})
 901					continue
 902				}
 903				textPart, ok := ai.AsContentType[ai.TextPart](c)
 904				if !ok {
 905					warnings = append(warnings, ai.CallWarning{
 906						Type:    ai.CallWarningTypeOther,
 907						Message: "system prompt text part does not have the right type",
 908					})
 909					continue
 910				}
 911				text := textPart.Text
 912				if strings.TrimSpace(text) != "" {
 913					systemPromptParts = append(systemPromptParts, textPart.Text)
 914				}
 915			}
 916			if len(systemPromptParts) == 0 {
 917				warnings = append(warnings, ai.CallWarning{
 918					Type:    ai.CallWarningTypeOther,
 919					Message: "system prompt has no text parts",
 920				})
 921				continue
 922			}
 923			messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
 924		case ai.MessageRoleUser:
 925			// simple user message just text content
 926			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
 927				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
 928				if !ok {
 929					warnings = append(warnings, ai.CallWarning{
 930						Type:    ai.CallWarningTypeOther,
 931						Message: "user message text part does not have the right type",
 932					})
 933					continue
 934				}
 935				messages = append(messages, openai.UserMessage(textPart.Text))
 936				continue
 937			}
 938			// text content and attachments
 939			// for now we only support image content later we need to check
 940			// TODO: add the supported media types to the language model so we
 941			//  can use that to validate the data here.
 942			var content []openai.ChatCompletionContentPartUnionParam
 943			for _, c := range msg.Content {
 944				switch c.GetType() {
 945				case ai.ContentTypeText:
 946					textPart, ok := ai.AsContentType[ai.TextPart](c)
 947					if !ok {
 948						warnings = append(warnings, ai.CallWarning{
 949							Type:    ai.CallWarningTypeOther,
 950							Message: "user message text part does not have the right type",
 951						})
 952						continue
 953					}
 954					content = append(content, openai.ChatCompletionContentPartUnionParam{
 955						OfText: &openai.ChatCompletionContentPartTextParam{
 956							Text: textPart.Text,
 957						},
 958					})
 959				case ai.ContentTypeFile:
 960					filePart, ok := ai.AsContentType[ai.FilePart](c)
 961					if !ok {
 962						warnings = append(warnings, ai.CallWarning{
 963							Type:    ai.CallWarningTypeOther,
 964							Message: "user message file part does not have the right type",
 965						})
 966						continue
 967					}
 968
 969					switch {
 970					case strings.HasPrefix(filePart.MediaType, "image/"):
 971						// Handle image files
 972						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
 973						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
 974						imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}
 975
 976						// Check for provider-specific options like image detail
 977						if providerOptions, ok := filePart.ProviderOptions[Name]; ok {
 978							if detail, ok := providerOptions.(*ProviderFileOptions); ok {
 979								imageURL.Detail = detail.ImageDetail
 980							}
 981						}
 982
 983						imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
 984						content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
 985
 986					case filePart.MediaType == "audio/wav":
 987						// Handle WAV audio files
 988						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
 989						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
 990							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
 991								Data:   base64Encoded,
 992								Format: "wav",
 993							},
 994						}
 995						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
 996
 997					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
 998						// Handle MP3 audio files
 999						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
1000						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
1001							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
1002								Data:   base64Encoded,
1003								Format: "mp3",
1004							},
1005						}
1006						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
1007
1008					case filePart.MediaType == "application/pdf":
1009						// Handle PDF files
1010						dataStr := string(filePart.Data)
1011
1012						// Check if data looks like a file ID (starts with "file-")
1013						if strings.HasPrefix(dataStr, "file-") {
1014							fileBlock := openai.ChatCompletionContentPartFileParam{
1015								File: openai.ChatCompletionContentPartFileFileParam{
1016									FileID: param.NewOpt(dataStr),
1017								},
1018							}
1019							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
1020						} else {
1021							// Handle as base64 data
1022							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
1023							data := "data:application/pdf;base64," + base64Encoded
1024
1025							filename := filePart.Filename
1026							if filename == "" {
1027								// Generate default filename based on content index
1028								filename = fmt.Sprintf("part-%d.pdf", len(content))
1029							}
1030
1031							fileBlock := openai.ChatCompletionContentPartFileParam{
1032								File: openai.ChatCompletionContentPartFileFileParam{
1033									Filename: param.NewOpt(filename),
1034									FileData: param.NewOpt(data),
1035								},
1036							}
1037							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
1038						}
1039
1040					default:
1041						warnings = append(warnings, ai.CallWarning{
1042							Type:    ai.CallWarningTypeOther,
1043							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
1044						})
1045					}
1046				}
1047			}
1048			messages = append(messages, openai.UserMessage(content))
1049		case ai.MessageRoleAssistant:
1050			// simple assistant message just text content
1051			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
1052				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
1053				if !ok {
1054					warnings = append(warnings, ai.CallWarning{
1055						Type:    ai.CallWarningTypeOther,
1056						Message: "assistant message text part does not have the right type",
1057					})
1058					continue
1059				}
1060				messages = append(messages, openai.AssistantMessage(textPart.Text))
1061				continue
1062			}
1063			assistantMsg := openai.ChatCompletionAssistantMessageParam{
1064				Role: "assistant",
1065			}
1066			for _, c := range msg.Content {
1067				switch c.GetType() {
1068				case ai.ContentTypeText:
1069					textPart, ok := ai.AsContentType[ai.TextPart](c)
1070					if !ok {
1071						warnings = append(warnings, ai.CallWarning{
1072							Type:    ai.CallWarningTypeOther,
1073							Message: "assistant message text part does not have the right type",
1074						})
1075						continue
1076					}
1077					assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
1078						OfString: param.NewOpt(textPart.Text),
1079					}
1080				case ai.ContentTypeToolCall:
1081					toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
1082					if !ok {
1083						warnings = append(warnings, ai.CallWarning{
1084							Type:    ai.CallWarningTypeOther,
1085							Message: "assistant message tool part does not have the right type",
1086						})
1087						continue
1088					}
1089					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
1090						openai.ChatCompletionMessageToolCallUnionParam{
1091							OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
1092								ID:   toolCallPart.ToolCallID,
1093								Type: "function",
1094								Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
1095									Name:      toolCallPart.ToolName,
1096									Arguments: toolCallPart.Input,
1097								},
1098							},
1099						})
1100				}
1101			}
1102			messages = append(messages, openai.ChatCompletionMessageParamUnion{
1103				OfAssistant: &assistantMsg,
1104			})
1105		case ai.MessageRoleTool:
1106			for _, c := range msg.Content {
1107				if c.GetType() != ai.ContentTypeToolResult {
1108					warnings = append(warnings, ai.CallWarning{
1109						Type:    ai.CallWarningTypeOther,
1110						Message: "tool message can only have tool result content",
1111					})
1112					continue
1113				}
1114
1115				toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
1116				if !ok {
1117					warnings = append(warnings, ai.CallWarning{
1118						Type:    ai.CallWarningTypeOther,
1119						Message: "tool message result part does not have the right type",
1120					})
1121					continue
1122				}
1123
1124				switch toolResultPart.Output.GetType() {
1125				case ai.ToolResultContentTypeText:
1126					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
1127					if !ok {
1128						warnings = append(warnings, ai.CallWarning{
1129							Type:    ai.CallWarningTypeOther,
1130							Message: "tool result output does not have the right type",
1131						})
1132						continue
1133					}
1134					messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
1135				case ai.ToolResultContentTypeError:
1136					// TODO: check if better handling is needed
1137					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
1138					if !ok {
1139						warnings = append(warnings, ai.CallWarning{
1140							Type:    ai.CallWarningTypeOther,
1141							Message: "tool result output does not have the right type",
1142						})
1143						continue
1144					}
1145					messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
1146				}
1147			}
1148		}
1149	}
1150	return messages, warnings
1151}
1152
1153// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
1154func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
1155	var annotations []openai.ChatCompletionMessageAnnotation
1156
1157	// Parse the raw JSON to extract annotations
1158	var deltaData map[string]any
1159	if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
1160		return annotations
1161	}
1162
1163	// Check if annotations exist in the delta
1164	if annotationsData, ok := deltaData["annotations"].([]any); ok {
1165		for _, annotationData := range annotationsData {
1166			if annotationMap, ok := annotationData.(map[string]any); ok {
1167				if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
1168					if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
1169						annotation := openai.ChatCompletionMessageAnnotation{
1170							Type: "url_citation",
1171							URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
1172								URL:   urlCitationData["url"].(string),
1173								Title: urlCitationData["title"].(string),
1174							},
1175						}
1176						annotations = append(annotations, annotation)
1177					}
1178				}
1179			}
1180		}
1181	}
1182
1183	return annotations
1184}