1package openai
   2
   3import (
   4	"cmp"
   5	"context"
   6	"encoding/base64"
   7	"encoding/json"
   8	"errors"
   9	"fmt"
  10	"io"
  11	"maps"
  12	"strings"
  13
  14	"github.com/charmbracelet/fantasy/ai"
  15	xjson "github.com/charmbracelet/x/json"
  16	"github.com/google/uuid"
  17	"github.com/openai/openai-go/v2"
  18	"github.com/openai/openai-go/v2/option"
  19	"github.com/openai/openai-go/v2/packages/param"
  20	"github.com/openai/openai-go/v2/shared"
  21)
  22
const (
	// Name is the canonical provider identifier; it is also the key under
	// which provider-specific options and metadata are stored.
	Name       = "openai"
	// DefaultURL is the OpenAI API base URL used when none is configured.
	DefaultURL = "https://api.openai.com/v1"
)
  27
// provider is the OpenAI implementation of ai.Provider. It holds the
// resolved configuration produced by New.
type provider struct {
	options options
}
  31
// PrepareCallWithOptions is the signature of the hook that applies
// provider options from an ai.Call onto the outgoing request params; it
// returns any warnings produced while doing so.
type PrepareCallWithOptions = func(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error)
  33
// Hooks carries optional callbacks that customize request preparation.
type Hooks struct {
	// PrepareCallWithOptions, when non-nil, replaces the package default
	// prepareCallWithOptions step in prepareParams.
	PrepareCallWithOptions PrepareCallWithOptions
}
  37
// options is the resolved provider configuration. Fields are set through
// the With* functional options and defaulted in New.
type options struct {
	baseURL      string            // API base URL; DefaultURL when left empty
	apiKey       string            // API key; omitted from client options when empty
	organization string            // forwarded as the OpenAi-Organization header when set
	project      string            // forwarded as the OpenAi-Project header when set
	name         string            // provider name reported by models; defaults to Name
	hooks        Hooks             // optional request-preparation hooks
	headers      map[string]string // extra headers applied to every request
	client       option.HTTPClient // custom HTTP client; SDK default when nil
}
  48
// Option mutates the provider configuration; pass values to New.
type Option = func(*options)
  50
  51func New(opts ...Option) ai.Provider {
  52	providerOptions := options{
  53		headers: map[string]string{},
  54	}
  55	for _, o := range opts {
  56		o(&providerOptions)
  57	}
  58
  59	providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
  60	providerOptions.name = cmp.Or(providerOptions.name, Name)
  61
  62	if providerOptions.organization != "" {
  63		providerOptions.headers["OpenAi-Organization"] = providerOptions.organization
  64	}
  65	if providerOptions.project != "" {
  66		providerOptions.headers["OpenAi-Project"] = providerOptions.project
  67	}
  68
  69	return &provider{options: providerOptions}
  70}
  71
// WithBaseURL overrides the API base URL (defaults to DefaultURL).
func WithBaseURL(baseURL string) Option {
	return func(o *options) {
		o.baseURL = baseURL
	}
}
  77
// WithAPIKey sets the API key used to authenticate requests.
func WithAPIKey(apiKey string) Option {
	return func(o *options) {
		o.apiKey = apiKey
	}
}
  83
// WithOrganization sets the OpenAI organization, sent as the
// OpenAi-Organization header on every request.
func WithOrganization(organization string) Option {
	return func(o *options) {
		o.organization = organization
	}
}
  89
// WithProject sets the OpenAI project, sent as the OpenAi-Project header
// on every request.
func WithProject(project string) Option {
	return func(o *options) {
		o.project = project
	}
}
  95
// WithName overrides the provider name reported by language models
// (defaults to Name).
func WithName(name string) Option {
	return func(o *options) {
		o.name = name
	}
}
 101
// WithHeaders merges the given headers into the set applied to every
// request; later values overwrite earlier ones for the same key.
func WithHeaders(headers map[string]string) Option {
	return func(o *options) {
		maps.Copy(o.headers, headers)
	}
}
 107
// WithHTTPClient sets a custom HTTP client for API requests.
func WithHTTPClient(client option.HTTPClient) Option {
	return func(o *options) {
		o.client = client
	}
}
 113
// WithHooks installs request-preparation hooks (see Hooks).
func WithHooks(hooks Hooks) Option {
	return func(o *options) {
		o.hooks = hooks
	}
}
 119
 120// LanguageModel implements ai.Provider.
 121func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
 122	openaiClientOptions := []option.RequestOption{}
 123	if o.options.apiKey != "" {
 124		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
 125	}
 126	if o.options.baseURL != "" {
 127		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
 128	}
 129
 130	for key, value := range o.options.headers {
 131		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
 132	}
 133
 134	if o.options.client != nil {
 135		openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
 136	}
 137
 138	return languageModel{
 139		modelID:  modelID,
 140		provider: o.options.name,
 141		options:  o.options,
 142		client:   openai.NewClient(openaiClientOptions...),
 143	}, nil
 144}
 145
// languageModel is the chat-completions-backed implementation of
// ai.LanguageModel for a single model ID.
type languageModel struct {
	provider string        // provider name reported by Provider()
	modelID  string        // OpenAI model identifier, e.g. "gpt-4o"
	client   openai.Client // configured OpenAI SDK client
	options  options       // provider configuration this model was built from
}
 152
// Model implements ai.LanguageModel. It returns the model identifier
// this instance was created with.
func (o languageModel) Model() string {
	return o.modelID
}
 157
// Provider implements ai.LanguageModel. It returns the configured
// provider name (o.options.name, defaulting to Name).
func (o languageModel) Provider() string {
	return o.provider
}
 162
// prepareParams translates an ai.Call into openai.ChatCompletionNewParams,
// collecting non-fatal ai.CallWarning values for settings the target model
// does not support. It applies reasoning-model and search-preview-model
// restrictions, runs the (possibly hooked) provider-options step, and maps
// tools and tool choice.
func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
	params := &openai.ChatCompletionNewParams{}
	messages, warnings := toPrompt(call.Prompt)
	// top_k has no equivalent in the chat completions API.
	if call.TopK != nil {
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedSetting,
			Setting: "top_k",
		})
	}
	params.Messages = messages
	params.Model = o.modelID

	// Copy the optional sampling parameters that have direct equivalents.
	if call.MaxOutputTokens != nil {
		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
	}
	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}
	if call.FrequencyPenalty != nil {
		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
	}
	if call.PresencePenalty != nil {
		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
	}

	if isReasoningModel(o.modelID) {
		// remove unsupported settings for reasoning models
		// see https://platform.openai.com/docs/guides/reasoning#limitations
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for reasoning models",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopP",
				Details: "TopP is not supported for reasoning models",
			})
		}
		if call.FrequencyPenalty != nil {
			params.FrequencyPenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "FrequencyPenalty",
				Details: "FrequencyPenalty is not supported for reasoning models",
			})
		}
		if call.PresencePenalty != nil {
			params.PresencePenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "PresencePenalty",
				Details: "PresencePenalty is not supported for reasoning models",
			})
		}

		// reasoning models use max_completion_tokens instead of max_tokens
		if call.MaxOutputTokens != nil {
			// Only set it if the provider-options hook has not already
			// populated max_completion_tokens.
			if !params.MaxCompletionTokens.Valid() {
				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
			}
			params.MaxTokens = param.Opt[int64]{}
		}
	}

	// Handle search preview models
	if isSearchPreviewModel(o.modelID) {
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for the search preview models and has been removed.",
			})
		}
	}

	// Apply provider-specific options, preferring a user-installed hook
	// over the package default.
	prepareOptions := prepareCallWithOptions
	if o.options.hooks.PrepareCallWithOptions != nil {
		prepareOptions = o.options.hooks.PrepareCallWithOptions
	}

	optionsWarnings, err := prepareOptions(o, params, call)
	if err != nil {
		return nil, nil, err
	}

	if len(optionsWarnings) > 0 {
		warnings = append(warnings, optionsWarnings...)
	}

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}
	return params, warnings, nil
}
 272
 273func (o languageModel) handleError(err error) error {
 274	var apiErr *openai.Error
 275	if errors.As(err, &apiErr) {
 276		requestDump := apiErr.DumpRequest(true)
 277		responseDump := apiErr.DumpResponse(true)
 278		headers := map[string]string{}
 279		for k, h := range apiErr.Response.Header {
 280			v := h[len(h)-1]
 281			headers[strings.ToLower(k)] = v
 282		}
 283		return ai.NewAPICallError(
 284			apiErr.Message,
 285			apiErr.Request.URL.String(),
 286			string(requestDump),
 287			apiErr.StatusCode,
 288			headers,
 289			string(responseDump),
 290			apiErr,
 291			false,
 292		)
 293	}
 294	return err
 295}
 296
// Generate implements ai.LanguageModel. It performs a single
// (non-streaming) chat completion and converts the first choice into
// ai.Content parts: text, tool calls, and url_citation sources, plus
// usage and provider metadata (logprobs, prediction-token counts).
func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, o.handleError(err)
	}

	if len(response.Choices) == 0 {
		return nil, errors.New("no response generated")
	}
	choice := response.Choices[0]
	// Capacity: one text part plus every tool call and annotation.
	content := make([]ai.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
	text := choice.Message.Content
	if text != "" {
		content = append(content, ai.TextContent{
			Text: text,
		})
	}

	for _, tc := range choice.Message.ToolCalls {
		// Synthesize an ID when the backend omits one so downstream
		// correlation of tool calls/results still works.
		toolCallID := tc.ID
		if toolCallID == "" {
			toolCallID = uuid.NewString()
		}
		content = append(content, ai.ToolCallContent{
			ProviderExecuted: false, // TODO: update when handling other tools
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Handle annotations/citations
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, ai.SourceContent{
				SourceType: ai.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	completionTokenDetails := response.Usage.CompletionTokensDetails
	promptTokenDetails := response.Usage.PromptTokensDetails

	// Build provider metadata
	providerMetadata := &ProviderMetadata{}
	// Add logprobs if available
	if len(choice.Logprobs.Content) > 0 {
		providerMetadata.Logprobs = choice.Logprobs.Content
	}

	// Add prediction tokens if available
	if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
		if completionTokenDetails.AcceptedPredictionTokens > 0 {
			providerMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
		}
		if completionTokenDetails.RejectedPredictionTokens > 0 {
			providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
		}
	}

	return &ai.Response{
		Content: content,
		Usage: ai.Usage{
			InputTokens:     response.Usage.PromptTokens,
			OutputTokens:    response.Usage.CompletionTokens,
			TotalTokens:     response.Usage.TotalTokens,
			ReasoningTokens: completionTokenDetails.ReasoningTokens,
			CacheReadTokens: promptTokenDetails.CachedTokens,
		},
		FinishReason: mapOpenAiFinishReason(choice.FinishReason),
		ProviderMetadata: ai.ProviderMetadata{
			Name: providerMetadata,
		},
		Warnings: warnings,
	}, nil
}
 380
// toolCall tracks the incremental state of one streamed tool call,
// keyed by its index within the delta stream.
type toolCall struct {
	id          string // tool call ID from the first delta
	name        string // function name from the first delta
	arguments   string // accumulated JSON argument text
	hasFinished bool   // set once a complete ToolCall part has been yielded
}
 387
 388// Stream implements ai.LanguageModel.
 389func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
 390	params, warnings, err := o.prepareParams(call)
 391	if err != nil {
 392		return nil, err
 393	}
 394
 395	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
 396		IncludeUsage: openai.Bool(true),
 397	}
 398
 399	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
 400	isActiveText := false
 401	toolCalls := make(map[int64]toolCall)
 402
 403	// Build provider metadata for streaming
 404	streamProviderMetadata := &ProviderMetadata{}
 405	acc := openai.ChatCompletionAccumulator{}
 406	var usage ai.Usage
 407	return func(yield func(ai.StreamPart) bool) {
 408		if len(warnings) > 0 {
 409			if !yield(ai.StreamPart{
 410				Type:     ai.StreamPartTypeWarnings,
 411				Warnings: warnings,
 412			}) {
 413				return
 414			}
 415		}
 416		for stream.Next() {
 417			chunk := stream.Current()
 418			acc.AddChunk(chunk)
 419			if chunk.Usage.TotalTokens > 0 {
 420				// we do this here because the acc does not add prompt details
 421				completionTokenDetails := chunk.Usage.CompletionTokensDetails
 422				promptTokenDetails := chunk.Usage.PromptTokensDetails
 423				usage = ai.Usage{
 424					InputTokens:     chunk.Usage.PromptTokens,
 425					OutputTokens:    chunk.Usage.CompletionTokens,
 426					TotalTokens:     chunk.Usage.TotalTokens,
 427					ReasoningTokens: completionTokenDetails.ReasoningTokens,
 428					CacheReadTokens: promptTokenDetails.CachedTokens,
 429				}
 430
 431				// Add prediction tokens if available
 432				if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
 433					if completionTokenDetails.AcceptedPredictionTokens > 0 {
 434						streamProviderMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
 435					}
 436					if completionTokenDetails.RejectedPredictionTokens > 0 {
 437						streamProviderMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
 438					}
 439				}
 440			}
 441			if len(chunk.Choices) == 0 {
 442				continue
 443			}
 444			for _, choice := range chunk.Choices {
 445				switch {
 446				case choice.Delta.Content != "":
 447					if !isActiveText {
 448						isActiveText = true
 449						if !yield(ai.StreamPart{
 450							Type: ai.StreamPartTypeTextStart,
 451							ID:   "0",
 452						}) {
 453							return
 454						}
 455					}
 456					if !yield(ai.StreamPart{
 457						Type:  ai.StreamPartTypeTextDelta,
 458						ID:    "0",
 459						Delta: choice.Delta.Content,
 460					}) {
 461						return
 462					}
 463				case len(choice.Delta.ToolCalls) > 0:
 464					if isActiveText {
 465						isActiveText = false
 466						if !yield(ai.StreamPart{
 467							Type: ai.StreamPartTypeTextEnd,
 468							ID:   "0",
 469						}) {
 470							return
 471						}
 472					}
 473
 474					for _, toolCallDelta := range choice.Delta.ToolCalls {
 475						if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
 476							if existingToolCall.hasFinished {
 477								continue
 478							}
 479							if toolCallDelta.Function.Arguments != "" {
 480								existingToolCall.arguments += toolCallDelta.Function.Arguments
 481							}
 482							if !yield(ai.StreamPart{
 483								Type:  ai.StreamPartTypeToolInputDelta,
 484								ID:    existingToolCall.id,
 485								Delta: toolCallDelta.Function.Arguments,
 486							}) {
 487								return
 488							}
 489							toolCalls[toolCallDelta.Index] = existingToolCall
 490							if xjson.IsValid(existingToolCall.arguments) {
 491								if !yield(ai.StreamPart{
 492									Type: ai.StreamPartTypeToolInputEnd,
 493									ID:   existingToolCall.id,
 494								}) {
 495									return
 496								}
 497
 498								if !yield(ai.StreamPart{
 499									Type:          ai.StreamPartTypeToolCall,
 500									ID:            existingToolCall.id,
 501									ToolCallName:  existingToolCall.name,
 502									ToolCallInput: existingToolCall.arguments,
 503								}) {
 504									return
 505								}
 506								existingToolCall.hasFinished = true
 507								toolCalls[toolCallDelta.Index] = existingToolCall
 508							}
 509						} else {
 510							// Does not exist
 511							var err error
 512							if toolCallDelta.Type != "function" {
 513								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
 514							}
 515							if toolCallDelta.ID == "" {
 516								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
 517							}
 518							if toolCallDelta.Function.Name == "" {
 519								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
 520							}
 521							if err != nil {
 522								yield(ai.StreamPart{
 523									Type:  ai.StreamPartTypeError,
 524									Error: o.handleError(stream.Err()),
 525								})
 526								return
 527							}
 528
 529							if !yield(ai.StreamPart{
 530								Type:         ai.StreamPartTypeToolInputStart,
 531								ID:           toolCallDelta.ID,
 532								ToolCallName: toolCallDelta.Function.Name,
 533							}) {
 534								return
 535							}
 536							toolCalls[toolCallDelta.Index] = toolCall{
 537								id:        toolCallDelta.ID,
 538								name:      toolCallDelta.Function.Name,
 539								arguments: toolCallDelta.Function.Arguments,
 540							}
 541
 542							exTc := toolCalls[toolCallDelta.Index]
 543							if exTc.arguments != "" {
 544								if !yield(ai.StreamPart{
 545									Type:  ai.StreamPartTypeToolInputDelta,
 546									ID:    exTc.id,
 547									Delta: exTc.arguments,
 548								}) {
 549									return
 550								}
 551								if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
 552									if !yield(ai.StreamPart{
 553										Type: ai.StreamPartTypeToolInputEnd,
 554										ID:   toolCallDelta.ID,
 555									}) {
 556										return
 557									}
 558
 559									if !yield(ai.StreamPart{
 560										Type:          ai.StreamPartTypeToolCall,
 561										ID:            exTc.id,
 562										ToolCallName:  exTc.name,
 563										ToolCallInput: exTc.arguments,
 564									}) {
 565										return
 566									}
 567									exTc.hasFinished = true
 568									toolCalls[toolCallDelta.Index] = exTc
 569								}
 570							}
 571							continue
 572						}
 573					}
 574				}
 575			}
 576
 577			// Check for annotations in the delta's raw JSON
 578			for _, choice := range chunk.Choices {
 579				if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
 580					for _, annotation := range annotations {
 581						if annotation.Type == "url_citation" {
 582							if !yield(ai.StreamPart{
 583								Type:       ai.StreamPartTypeSource,
 584								ID:         uuid.NewString(),
 585								SourceType: ai.SourceTypeURL,
 586								URL:        annotation.URLCitation.URL,
 587								Title:      annotation.URLCitation.Title,
 588							}) {
 589								return
 590							}
 591						}
 592					}
 593				}
 594			}
 595		}
 596		err := stream.Err()
 597		if err == nil || errors.Is(err, io.EOF) {
 598			// finished
 599			if isActiveText {
 600				isActiveText = false
 601				if !yield(ai.StreamPart{
 602					Type: ai.StreamPartTypeTextEnd,
 603					ID:   "0",
 604				}) {
 605					return
 606				}
 607			}
 608
 609			// Add logprobs if available
 610			if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
 611				streamProviderMetadata.Logprobs = acc.Choices[0].Logprobs.Content
 612			}
 613
 614			// Handle annotations/citations from accumulated response
 615			if len(acc.Choices) > 0 {
 616				for _, annotation := range acc.Choices[0].Message.Annotations {
 617					if annotation.Type == "url_citation" {
 618						if !yield(ai.StreamPart{
 619							Type:       ai.StreamPartTypeSource,
 620							ID:         acc.ID,
 621							SourceType: ai.SourceTypeURL,
 622							URL:        annotation.URLCitation.URL,
 623							Title:      annotation.URLCitation.Title,
 624						}) {
 625							return
 626						}
 627					}
 628				}
 629			}
 630
 631			finishReason := mapOpenAiFinishReason(acc.Choices[0].FinishReason)
 632			yield(ai.StreamPart{
 633				Type:         ai.StreamPartTypeFinish,
 634				Usage:        usage,
 635				FinishReason: finishReason,
 636				ProviderMetadata: ai.ProviderMetadata{
 637					Name: streamProviderMetadata,
 638				},
 639			})
 640			return
 641		} else {
 642			yield(ai.StreamPart{
 643				Type:  ai.StreamPartTypeError,
 644				Error: o.handleError(err),
 645			})
 646			return
 647		}
 648	}, nil
 649}
 650
 651func (o *provider) ParseOptions(data map[string]any) (ai.ProviderOptionsData, error) {
 652	var options ProviderOptions
 653	if err := ai.ParseOptions(data, &options); err != nil {
 654		return nil, err
 655	}
 656	return &options, nil
 657}
 658
// Name implements ai.Provider.
//
// NOTE(review): this returns the package constant Name even when the
// provider was built with WithName, whereas languageModel.Provider()
// reports o.options.name — confirm whether Name() should return the
// configured name instead.
func (o *provider) Name() string {
	return Name
}
 662
 663func prepareCallWithOptions(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
 664	if call.ProviderOptions == nil {
 665		return nil, nil
 666	}
 667	var warnings []ai.CallWarning
 668	providerOptions := &ProviderOptions{}
 669	if v, ok := call.ProviderOptions[Name]; ok {
 670		providerOptions, ok = v.(*ProviderOptions)
 671		if !ok {
 672			return nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
 673		}
 674	}
 675
 676	if providerOptions.LogitBias != nil {
 677		params.LogitBias = providerOptions.LogitBias
 678	}
 679	if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
 680		providerOptions.LogProbs = nil
 681	}
 682	if providerOptions.LogProbs != nil {
 683		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
 684	}
 685	if providerOptions.TopLogProbs != nil {
 686		params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
 687	}
 688	if providerOptions.User != nil {
 689		params.User = param.NewOpt(*providerOptions.User)
 690	}
 691	if providerOptions.ParallelToolCalls != nil {
 692		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
 693	}
 694	if providerOptions.MaxCompletionTokens != nil {
 695		params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
 696	}
 697
 698	if providerOptions.TextVerbosity != nil {
 699		params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
 700	}
 701	if providerOptions.Prediction != nil {
 702		// Convert map[string]any to ChatCompletionPredictionContentParam
 703		if content, ok := providerOptions.Prediction["content"]; ok {
 704			if contentStr, ok := content.(string); ok {
 705				params.Prediction = openai.ChatCompletionPredictionContentParam{
 706					Content: openai.ChatCompletionPredictionContentContentUnionParam{
 707						OfString: param.NewOpt(contentStr),
 708					},
 709				}
 710			}
 711		}
 712	}
 713	if providerOptions.Store != nil {
 714		params.Store = param.NewOpt(*providerOptions.Store)
 715	}
 716	if providerOptions.Metadata != nil {
 717		// Convert map[string]any to map[string]string
 718		metadata := make(map[string]string)
 719		for k, v := range providerOptions.Metadata {
 720			if str, ok := v.(string); ok {
 721				metadata[k] = str
 722			}
 723		}
 724		params.Metadata = metadata
 725	}
 726	if providerOptions.PromptCacheKey != nil {
 727		params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
 728	}
 729	if providerOptions.SafetyIdentifier != nil {
 730		params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
 731	}
 732	if providerOptions.ServiceTier != nil {
 733		params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
 734	}
 735
 736	if providerOptions.ReasoningEffort != nil {
 737		switch *providerOptions.ReasoningEffort {
 738		case ReasoningEffortMinimal:
 739			params.ReasoningEffort = shared.ReasoningEffortMinimal
 740		case ReasoningEffortLow:
 741			params.ReasoningEffort = shared.ReasoningEffortLow
 742		case ReasoningEffortMedium:
 743			params.ReasoningEffort = shared.ReasoningEffortMedium
 744		case ReasoningEffortHigh:
 745			params.ReasoningEffort = shared.ReasoningEffortHigh
 746		default:
 747			return nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
 748		}
 749	}
 750
 751	if isReasoningModel(model.Model()) {
 752		if providerOptions.LogitBias != nil {
 753			params.LogitBias = nil
 754			warnings = append(warnings, ai.CallWarning{
 755				Type:    ai.CallWarningTypeUnsupportedSetting,
 756				Setting: "LogitBias",
 757				Message: "LogitBias is not supported for reasoning models",
 758			})
 759		}
 760		if providerOptions.LogProbs != nil {
 761			params.Logprobs = param.Opt[bool]{}
 762			warnings = append(warnings, ai.CallWarning{
 763				Type:    ai.CallWarningTypeUnsupportedSetting,
 764				Setting: "Logprobs",
 765				Message: "Logprobs is not supported for reasoning models",
 766			})
 767		}
 768		if providerOptions.TopLogProbs != nil {
 769			params.TopLogprobs = param.Opt[int64]{}
 770			warnings = append(warnings, ai.CallWarning{
 771				Type:    ai.CallWarningTypeUnsupportedSetting,
 772				Setting: "TopLogprobs",
 773				Message: "TopLogprobs is not supported for reasoning models",
 774			})
 775		}
 776	}
 777
 778	// Handle service tier validation
 779	if providerOptions.ServiceTier != nil {
 780		serviceTier := *providerOptions.ServiceTier
 781		if serviceTier == "flex" && !supportsFlexProcessing(model.Model()) {
 782			params.ServiceTier = ""
 783			warnings = append(warnings, ai.CallWarning{
 784				Type:    ai.CallWarningTypeUnsupportedSetting,
 785				Setting: "ServiceTier",
 786				Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
 787			})
 788		} else if serviceTier == "priority" && !supportsPriorityProcessing(model.Model()) {
 789			params.ServiceTier = ""
 790			warnings = append(warnings, ai.CallWarning{
 791				Type:    ai.CallWarningTypeUnsupportedSetting,
 792				Setting: "ServiceTier",
 793				Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported",
 794			})
 795		}
 796	}
 797	return warnings, nil
 798}
 799
 800func mapOpenAiFinishReason(finishReason string) ai.FinishReason {
 801	switch finishReason {
 802	case "stop":
 803		return ai.FinishReasonStop
 804	case "length":
 805		return ai.FinishReasonLength
 806	case "content_filter":
 807		return ai.FinishReasonContentFilter
 808	case "function_call", "tool_calls":
 809		return ai.FinishReasonToolCalls
 810	default:
 811		return ai.FinishReasonUnknown
 812	}
 813}
 814
 815func isReasoningModel(modelID string) bool {
 816	return strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5") || strings.HasPrefix(modelID, "gpt-5-chat")
 817}
 818
 819func isSearchPreviewModel(modelID string) bool {
 820	return strings.Contains(modelID, "search-preview")
 821}
 822
 823func supportsFlexProcessing(modelID string) bool {
 824	return strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini") || strings.HasPrefix(modelID, "gpt-5")
 825}
 826
 827func supportsPriorityProcessing(modelID string) bool {
 828	return strings.HasPrefix(modelID, "gpt-4") || strings.HasPrefix(modelID, "gpt-5") ||
 829		strings.HasPrefix(modelID, "gpt-5-mini") || strings.HasPrefix(modelID, "o3") ||
 830		strings.HasPrefix(modelID, "o4-mini")
 831}
 832
// toOpenAiTools converts the generic ai.Tool list and optional
// ai.ToolChoice into their OpenAI chat-completions equivalents. Tools
// other than function tools are not supported yet and produce warnings;
// a tool choice that is neither Auto nor None is treated as the name of
// a specific function the model is forced to call.
func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				// Declared as a function tool but not the expected concrete
				// type; skipped without a warning.
				continue
			}
			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
				OfFunction: &openai.ChatCompletionFunctionToolParam{
					Function: shared.FunctionDefinitionParam{
						Name:        ft.Name,
						Description: param.NewOpt(ft.Description),
						Parameters:  openai.FunctionParameters(ft.InputSchema),
						Strict:      param.NewOpt(false),
					},
					Type: "function",
				},
			})
			continue
		}

		// TODO: handle provider tool calls
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return openAiTools, openAiToolChoice, warnings
	}

	switch *toolChoice {
	case ai.ToolChoiceAuto:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("auto"),
		}
	case ai.ToolChoiceNone:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("none"),
		}
	default:
		// Force the model to call the named tool.
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
				Type: "function",
				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
					Name: string(*toolChoice),
				},
			},
		}
	}
	return openAiTools, openAiToolChoice, warnings
}
 886
 887func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
 888	var messages []openai.ChatCompletionMessageParamUnion
 889	var warnings []ai.CallWarning
 890	for _, msg := range prompt {
 891		switch msg.Role {
 892		case ai.MessageRoleSystem:
 893			var systemPromptParts []string
 894			for _, c := range msg.Content {
 895				if c.GetType() != ai.ContentTypeText {
 896					warnings = append(warnings, ai.CallWarning{
 897						Type:    ai.CallWarningTypeOther,
 898						Message: "system prompt can only have text content",
 899					})
 900					continue
 901				}
 902				textPart, ok := ai.AsContentType[ai.TextPart](c)
 903				if !ok {
 904					warnings = append(warnings, ai.CallWarning{
 905						Type:    ai.CallWarningTypeOther,
 906						Message: "system prompt text part does not have the right type",
 907					})
 908					continue
 909				}
 910				text := textPart.Text
 911				if strings.TrimSpace(text) != "" {
 912					systemPromptParts = append(systemPromptParts, textPart.Text)
 913				}
 914			}
 915			if len(systemPromptParts) == 0 {
 916				warnings = append(warnings, ai.CallWarning{
 917					Type:    ai.CallWarningTypeOther,
 918					Message: "system prompt has no text parts",
 919				})
 920				continue
 921			}
 922			messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
 923		case ai.MessageRoleUser:
 924			// simple user message just text content
 925			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
 926				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
 927				if !ok {
 928					warnings = append(warnings, ai.CallWarning{
 929						Type:    ai.CallWarningTypeOther,
 930						Message: "user message text part does not have the right type",
 931					})
 932					continue
 933				}
 934				messages = append(messages, openai.UserMessage(textPart.Text))
 935				continue
 936			}
 937			// text content and attachments
 938			// for now we only support image content later we need to check
 939			// TODO: add the supported media types to the language model so we
 940			//  can use that to validate the data here.
 941			var content []openai.ChatCompletionContentPartUnionParam
 942			for _, c := range msg.Content {
 943				switch c.GetType() {
 944				case ai.ContentTypeText:
 945					textPart, ok := ai.AsContentType[ai.TextPart](c)
 946					if !ok {
 947						warnings = append(warnings, ai.CallWarning{
 948							Type:    ai.CallWarningTypeOther,
 949							Message: "user message text part does not have the right type",
 950						})
 951						continue
 952					}
 953					content = append(content, openai.ChatCompletionContentPartUnionParam{
 954						OfText: &openai.ChatCompletionContentPartTextParam{
 955							Text: textPart.Text,
 956						},
 957					})
 958				case ai.ContentTypeFile:
 959					filePart, ok := ai.AsContentType[ai.FilePart](c)
 960					if !ok {
 961						warnings = append(warnings, ai.CallWarning{
 962							Type:    ai.CallWarningTypeOther,
 963							Message: "user message file part does not have the right type",
 964						})
 965						continue
 966					}
 967
 968					switch {
 969					case strings.HasPrefix(filePart.MediaType, "image/"):
 970						// Handle image files
 971						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
 972						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
 973						imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}
 974
 975						// Check for provider-specific options like image detail
 976						if providerOptions, ok := filePart.ProviderOptions[Name]; ok {
 977							if detail, ok := providerOptions.(*ProviderFileOptions); ok {
 978								imageURL.Detail = detail.ImageDetail
 979							}
 980						}
 981
 982						imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
 983						content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
 984
 985					case filePart.MediaType == "audio/wav":
 986						// Handle WAV audio files
 987						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
 988						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
 989							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
 990								Data:   base64Encoded,
 991								Format: "wav",
 992							},
 993						}
 994						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
 995
 996					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
 997						// Handle MP3 audio files
 998						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
 999						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
1000							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
1001								Data:   base64Encoded,
1002								Format: "mp3",
1003							},
1004						}
1005						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
1006
1007					case filePart.MediaType == "application/pdf":
1008						// Handle PDF files
1009						dataStr := string(filePart.Data)
1010
1011						// Check if data looks like a file ID (starts with "file-")
1012						if strings.HasPrefix(dataStr, "file-") {
1013							fileBlock := openai.ChatCompletionContentPartFileParam{
1014								File: openai.ChatCompletionContentPartFileFileParam{
1015									FileID: param.NewOpt(dataStr),
1016								},
1017							}
1018							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
1019						} else {
1020							// Handle as base64 data
1021							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
1022							data := "data:application/pdf;base64," + base64Encoded
1023
1024							filename := filePart.Filename
1025							if filename == "" {
1026								// Generate default filename based on content index
1027								filename = fmt.Sprintf("part-%d.pdf", len(content))
1028							}
1029
1030							fileBlock := openai.ChatCompletionContentPartFileParam{
1031								File: openai.ChatCompletionContentPartFileFileParam{
1032									Filename: param.NewOpt(filename),
1033									FileData: param.NewOpt(data),
1034								},
1035							}
1036							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
1037						}
1038
1039					default:
1040						warnings = append(warnings, ai.CallWarning{
1041							Type:    ai.CallWarningTypeOther,
1042							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
1043						})
1044					}
1045				}
1046			}
1047			messages = append(messages, openai.UserMessage(content))
1048		case ai.MessageRoleAssistant:
1049			// simple assistant message just text content
1050			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
1051				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
1052				if !ok {
1053					warnings = append(warnings, ai.CallWarning{
1054						Type:    ai.CallWarningTypeOther,
1055						Message: "assistant message text part does not have the right type",
1056					})
1057					continue
1058				}
1059				messages = append(messages, openai.AssistantMessage(textPart.Text))
1060				continue
1061			}
1062			assistantMsg := openai.ChatCompletionAssistantMessageParam{
1063				Role: "assistant",
1064			}
1065			for _, c := range msg.Content {
1066				switch c.GetType() {
1067				case ai.ContentTypeText:
1068					textPart, ok := ai.AsContentType[ai.TextPart](c)
1069					if !ok {
1070						warnings = append(warnings, ai.CallWarning{
1071							Type:    ai.CallWarningTypeOther,
1072							Message: "assistant message text part does not have the right type",
1073						})
1074						continue
1075					}
1076					assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
1077						OfString: param.NewOpt(textPart.Text),
1078					}
1079				case ai.ContentTypeToolCall:
1080					toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
1081					if !ok {
1082						warnings = append(warnings, ai.CallWarning{
1083							Type:    ai.CallWarningTypeOther,
1084							Message: "assistant message tool part does not have the right type",
1085						})
1086						continue
1087					}
1088					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
1089						openai.ChatCompletionMessageToolCallUnionParam{
1090							OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
1091								ID:   toolCallPart.ToolCallID,
1092								Type: "function",
1093								Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
1094									Name:      toolCallPart.ToolName,
1095									Arguments: toolCallPart.Input,
1096								},
1097							},
1098						})
1099				}
1100			}
1101			messages = append(messages, openai.ChatCompletionMessageParamUnion{
1102				OfAssistant: &assistantMsg,
1103			})
1104		case ai.MessageRoleTool:
1105			for _, c := range msg.Content {
1106				if c.GetType() != ai.ContentTypeToolResult {
1107					warnings = append(warnings, ai.CallWarning{
1108						Type:    ai.CallWarningTypeOther,
1109						Message: "tool message can only have tool result content",
1110					})
1111					continue
1112				}
1113
1114				toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
1115				if !ok {
1116					warnings = append(warnings, ai.CallWarning{
1117						Type:    ai.CallWarningTypeOther,
1118						Message: "tool message result part does not have the right type",
1119					})
1120					continue
1121				}
1122
1123				switch toolResultPart.Output.GetType() {
1124				case ai.ToolResultContentTypeText:
1125					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
1126					if !ok {
1127						warnings = append(warnings, ai.CallWarning{
1128							Type:    ai.CallWarningTypeOther,
1129							Message: "tool result output does not have the right type",
1130						})
1131						continue
1132					}
1133					messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
1134				case ai.ToolResultContentTypeError:
1135					// TODO: check if better handling is needed
1136					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
1137					if !ok {
1138						warnings = append(warnings, ai.CallWarning{
1139							Type:    ai.CallWarningTypeOther,
1140							Message: "tool result output does not have the right type",
1141						})
1142						continue
1143					}
1144					messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
1145				}
1146			}
1147		}
1148	}
1149	return messages, warnings
1150}
1151
1152// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
1153func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
1154	var annotations []openai.ChatCompletionMessageAnnotation
1155
1156	// Parse the raw JSON to extract annotations
1157	var deltaData map[string]any
1158	if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
1159		return annotations
1160	}
1161
1162	// Check if annotations exist in the delta
1163	if annotationsData, ok := deltaData["annotations"].([]any); ok {
1164		for _, annotationData := range annotationsData {
1165			if annotationMap, ok := annotationData.(map[string]any); ok {
1166				if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
1167					if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
1168						annotation := openai.ChatCompletionMessageAnnotation{
1169							Type: "url_citation",
1170							URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
1171								URL:   urlCitationData["url"].(string),
1172								Title: urlCitationData["title"].(string),
1173							},
1174						}
1175						annotations = append(annotations, annotation)
1176					}
1177				}
1178			}
1179		}
1180	}
1181
1182	return annotations
1183}