package openai

import (
	"cmp"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"maps"
	"strings"

	"github.com/charmbracelet/fantasy/ai"
	xjson "github.com/charmbracelet/x/json"
	"github.com/google/uuid"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/option"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/openai/openai-go/v2/shared"
)

const (
	ProviderName = "openai"
	DefaultURL   = "https://api.openai.com/v1"
)

type provider struct {
	options options
}

type options struct {
	baseURL      string
	apiKey       string
	organization string
	project      string
	name         string
	headers      map[string]string
	client       option.HTTPClient
}

type Option = func(*options)

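// New creates an OpenAI provider configured by the given options. A minimal
// usage sketch (the model ID and environment variable are illustrative, not
// prescribed by this package):
//
//	provider := openai.New(
//		openai.WithAPIKey(os.Getenv("OPENAI_API_KEY")),
//	)
//	model, err := provider.LanguageModel("gpt-4o")
//	if err != nil {
//		// handle error
//	}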
func New(opts ...Option) ai.Provider {
	providerOptions := options{
		headers: map[string]string{},
	}
	for _, o := range opts {
		o(&providerOptions)
	}

	providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
	providerOptions.name = cmp.Or(providerOptions.name, ProviderName)

	if providerOptions.organization != "" {
		providerOptions.headers["OpenAI-Organization"] = providerOptions.organization
	}
	if providerOptions.project != "" {
		providerOptions.headers["OpenAI-Project"] = providerOptions.project
	}

	return &provider{options: providerOptions}
}

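// WithBaseURL overrides the API base URL (defaults to DefaultURL).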
func WithBaseURL(baseURL string) Option {
	return func(o *options) {
		o.baseURL = baseURL
	}
}

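// WithAPIKey sets the API key used to authenticate requests.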
func WithAPIKey(apiKey string) Option {
	return func(o *options) {
		o.apiKey = apiKey
	}
}

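// WithOrganization sets the organization sent in the OpenAI-Organization header.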
func WithOrganization(organization string) Option {
	return func(o *options) {
		o.organization = organization
	}
}

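// WithProject sets the project sent in the OpenAI-Project header.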
func WithProject(project string) Option {
	return func(o *options) {
		o.project = project
	}
}

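// WithName overrides the provider name (defaults to ProviderName).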
func WithName(name string) Option {
	return func(o *options) {
		o.name = name
	}
}

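// WithHeaders adds extra headers to every request.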
func WithHeaders(headers map[string]string) Option {
	return func(o *options) {
		maps.Copy(o.headers, headers)
	}
}

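// WithHTTPClient sets a custom HTTP client used for all requests.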
func WithHTTPClient(client option.HTTPClient) Option {
	return func(o *options) {
		o.client = client
	}
}

// LanguageModel implements ai.Provider.
func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
	openaiClientOptions := []option.RequestOption{}
	if o.options.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
	}
	if o.options.baseURL != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
	}

	for key, value := range o.options.headers {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	if o.options.client != nil {
		openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
	}

	return languageModel{
		modelID:  modelID,
		provider: fmt.Sprintf("%s.chat", o.options.name),
		options:  o.options,
		client:   openai.NewClient(openaiClientOptions...),
	}, nil
}

type languageModel struct {
	provider string
	modelID  string
	client   openai.Client
	options  options
}

// Model implements ai.LanguageModel.
func (o languageModel) Model() string {
	return o.modelID
}

// Provider implements ai.LanguageModel.
func (o languageModel) Provider() string {
	return o.provider
}

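// prepareParams translates an ai.Call into OpenAI chat completion parameters,
// collecting warnings for settings the target model does not support.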
func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
	params := &openai.ChatCompletionNewParams{}
	messages, warnings := toPrompt(call.Prompt)
	providerOptions := &ProviderOptions{}
	if v, ok := call.ProviderOptions[Name]; ok {
		providerOptions, ok = v.(*ProviderOptions)
		if !ok {
			return nil, nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
		}
	}
	if call.TopK != nil {
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedSetting,
			Setting: "top_k",
		})
	}
	params.Messages = messages
	params.Model = o.modelID
	if providerOptions.LogitBias != nil {
		params.LogitBias = providerOptions.LogitBias
	}
	// When both are set, top_logprobs takes precedence.
	if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
		providerOptions.LogProbs = nil
	}
	if providerOptions.LogProbs != nil {
		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
	}
	if providerOptions.TopLogProbs != nil {
		params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
	}
	if providerOptions.User != nil {
		params.User = param.NewOpt(*providerOptions.User)
	}
	if providerOptions.ParallelToolCalls != nil {
		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
	}

	if call.MaxOutputTokens != nil {
		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
	}
	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}
	if call.FrequencyPenalty != nil {
		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
	}
	if call.PresencePenalty != nil {
		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
	}

	if providerOptions.MaxCompletionTokens != nil {
		params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
	}

	if providerOptions.TextVerbosity != nil {
		params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
	}
	if providerOptions.Prediction != nil {
		// Convert map[string]any to ChatCompletionPredictionContentParam
		if content, ok := providerOptions.Prediction["content"]; ok {
			if contentStr, ok := content.(string); ok {
				params.Prediction = openai.ChatCompletionPredictionContentParam{
					Content: openai.ChatCompletionPredictionContentContentUnionParam{
						OfString: param.NewOpt(contentStr),
					},
				}
			}
		}
	}
	if providerOptions.Store != nil {
		params.Store = param.NewOpt(*providerOptions.Store)
	}
	if providerOptions.Metadata != nil {
		// Convert map[string]any to map[string]string
		metadata := make(map[string]string)
		for k, v := range providerOptions.Metadata {
			if str, ok := v.(string); ok {
				metadata[k] = str
			}
		}
		params.Metadata = metadata
	}
	if providerOptions.PromptCacheKey != nil {
		params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
	}
	if providerOptions.SafetyIdentifier != nil {
		params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
	}
	if providerOptions.ServiceTier != nil {
		params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
	}

	if providerOptions.ReasoningEffort != nil {
		switch *providerOptions.ReasoningEffort {
		case ReasoningEffortMinimal:
			params.ReasoningEffort = shared.ReasoningEffortMinimal
		case ReasoningEffortLow:
			params.ReasoningEffort = shared.ReasoningEffortLow
		case ReasoningEffortMedium:
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case ReasoningEffortHigh:
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			return nil, nil, fmt.Errorf("reasoning effort `%s` not supported", *providerOptions.ReasoningEffort)
		}
	}

	if isReasoningModel(o.modelID) {
		// remove unsupported settings for reasoning models
		// see https://platform.openai.com/docs/guides/reasoning#limitations
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for reasoning models",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "top_p",
				Details: "top_p is not supported for reasoning models",
			})
		}
		if call.FrequencyPenalty != nil {
			params.FrequencyPenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "frequency_penalty",
				Details: "frequency_penalty is not supported for reasoning models",
			})
		}
		if call.PresencePenalty != nil {
			params.PresencePenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "presence_penalty",
				Details: "presence_penalty is not supported for reasoning models",
			})
		}
		if providerOptions.LogitBias != nil {
			params.LogitBias = nil
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "logit_bias",
				Message: "logit_bias is not supported for reasoning models",
			})
		}
		if providerOptions.LogProbs != nil {
			params.Logprobs = param.Opt[bool]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "logprobs",
				Message: "logprobs is not supported for reasoning models",
			})
		}
		if providerOptions.TopLogProbs != nil {
			params.TopLogprobs = param.Opt[int64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "top_logprobs",
				Message: "top_logprobs is not supported for reasoning models",
			})
		}

		// reasoning models use max_completion_tokens instead of max_tokens
		if call.MaxOutputTokens != nil {
			if providerOptions.MaxCompletionTokens == nil {
				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
			}
			params.MaxTokens = param.Opt[int64]{}
		}
	}

	// Handle search preview models
	if isSearchPreviewModel(o.modelID) {
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for the search preview models and has been removed.",
			})
		}
	}

	// Handle service tier validation
	if providerOptions.ServiceTier != nil {
		serviceTier := *providerOptions.ServiceTier
		if serviceTier == "flex" && !supportsFlexProcessing(o.modelID) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "service_tier",
				Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
			})
		} else if serviceTier == "priority" && !supportsPriorityProcessing(o.modelID) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "service_tier",
				Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access; gpt-5-nano is not supported",
			})
		}
	}

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}
	return params, warnings, nil
}

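// handleError converts OpenAI API errors into ai.APICallError values and
// passes all other errors through unchanged.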
func (o languageModel) handleError(err error) error {
	var apiErr *openai.Error
	if errors.As(err, &apiErr) {
		requestDump := apiErr.DumpRequest(true)
		responseDump := apiErr.DumpResponse(true)
		headers := map[string]string{}
		for k, h := range apiErr.Response.Header {
			v := h[len(h)-1]
			headers[strings.ToLower(k)] = v
		}
		return ai.NewAPICallError(
			apiErr.Message,
			apiErr.Request.URL.String(),
			string(requestDump),
			apiErr.StatusCode,
			headers,
			string(responseDump),
			apiErr,
			false,
		)
	}
	return err
}

// Generate implements ai.LanguageModel.
func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, o.handleError(err)
	}

	if len(response.Choices) == 0 {
		return nil, errors.New("no response generated")
	}
	choice := response.Choices[0]
	content := make([]ai.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
	text := choice.Message.Content
	if text != "" {
		content = append(content, ai.TextContent{
			Text: text,
		})
	}

	for _, tc := range choice.Message.ToolCalls {
		toolCallID := tc.ID
		if toolCallID == "" {
			toolCallID = uuid.NewString()
		}
		content = append(content, ai.ToolCallContent{
			ProviderExecuted: false, // TODO: update when handling other tools
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Handle annotations/citations
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, ai.SourceContent{
				SourceType: ai.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	completionTokenDetails := response.Usage.CompletionTokensDetails
	promptTokenDetails := response.Usage.PromptTokensDetails

	// Build provider metadata
	providerMetadata := &ProviderMetadata{}
	// Add logprobs if available
	if len(choice.Logprobs.Content) > 0 {
		providerMetadata.Logprobs = choice.Logprobs.Content
	}

	// Add prediction tokens if available
	if completionTokenDetails.AcceptedPredictionTokens > 0 {
		providerMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
	}
	if completionTokenDetails.RejectedPredictionTokens > 0 {
		providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
	}

	return &ai.Response{
		Content: content,
		Usage: ai.Usage{
			InputTokens:     response.Usage.PromptTokens,
			OutputTokens:    response.Usage.CompletionTokens,
			TotalTokens:     response.Usage.TotalTokens,
			ReasoningTokens: completionTokenDetails.ReasoningTokens,
			CacheReadTokens: promptTokenDetails.CachedTokens,
		},
		FinishReason: mapOpenAiFinishReason(choice.FinishReason),
		ProviderMetadata: ai.ProviderMetadata{
			Name: providerMetadata,
		},
		Warnings: warnings,
	}, nil
}

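// toolCall accumulates a streamed tool call across deltas until its
// arguments form valid JSON.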
type toolCall struct {
	id          string
	name        string
	arguments   string
	hasFinished bool
}

// Stream implements ai.LanguageModel.
func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}

	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
	isActiveText := false
	toolCalls := make(map[int64]toolCall)

	// Build provider metadata for streaming
	streamProviderMetadata := &ProviderMetadata{}
	acc := openai.ChatCompletionAccumulator{}
	var usage ai.Usage
	return func(yield func(ai.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(ai.StreamPart{
				Type:     ai.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}
		for stream.Next() {
			chunk := stream.Current()
			acc.AddChunk(chunk)
			if chunk.Usage.TotalTokens > 0 {
				// we do this here because the acc does not add prompt details
				completionTokenDetails := chunk.Usage.CompletionTokensDetails
				promptTokenDetails := chunk.Usage.PromptTokensDetails
				usage = ai.Usage{
					InputTokens:     chunk.Usage.PromptTokens,
					OutputTokens:    chunk.Usage.CompletionTokens,
					TotalTokens:     chunk.Usage.TotalTokens,
					ReasoningTokens: completionTokenDetails.ReasoningTokens,
					CacheReadTokens: promptTokenDetails.CachedTokens,
				}

				// Add prediction tokens if available
				if completionTokenDetails.AcceptedPredictionTokens > 0 {
					streamProviderMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
				}
				if completionTokenDetails.RejectedPredictionTokens > 0 {
					streamProviderMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
				}
			}
			if len(chunk.Choices) == 0 {
				continue
			}
			for _, choice := range chunk.Choices {
				switch {
				case choice.Delta.Content != "":
					if !isActiveText {
						isActiveText = true
						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeTextStart,
							ID:   "0",
						}) {
							return
						}
					}
					if !yield(ai.StreamPart{
						Type:  ai.StreamPartTypeTextDelta,
						ID:    "0",
						Delta: choice.Delta.Content,
					}) {
						return
					}
				case len(choice.Delta.ToolCalls) > 0:
					if isActiveText {
						isActiveText = false
						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeTextEnd,
							ID:   "0",
						}) {
							return
						}
					}

					for _, toolCallDelta := range choice.Delta.ToolCalls {
						if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
							if existingToolCall.hasFinished {
								continue
							}
							if toolCallDelta.Function.Arguments != "" {
								existingToolCall.arguments += toolCallDelta.Function.Arguments
							}
							if !yield(ai.StreamPart{
								Type:  ai.StreamPartTypeToolInputDelta,
								ID:    existingToolCall.id,
								Delta: toolCallDelta.Function.Arguments,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = existingToolCall
							if xjson.IsValid(existingToolCall.arguments) {
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeToolInputEnd,
									ID:   existingToolCall.id,
								}) {
									return
								}

								if !yield(ai.StreamPart{
									Type:          ai.StreamPartTypeToolCall,
									ID:            existingToolCall.id,
									ToolCallName:  existingToolCall.name,
									ToolCallInput: existingToolCall.arguments,
								}) {
									return
								}
								existingToolCall.hasFinished = true
								toolCalls[toolCallDelta.Index] = existingToolCall
							}
						} else {
							// First delta for this tool call: validate it and
							// emit a tool-input-start part.
							var err error
							if toolCallDelta.Type != "function" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
							}
							if toolCallDelta.ID == "" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
							}
							if toolCallDelta.Function.Name == "" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
							}
							if err != nil {
								yield(ai.StreamPart{
									Type:  ai.StreamPartTypeError,
									Error: err,
								})
								return
							}

							if !yield(ai.StreamPart{
								Type:         ai.StreamPartTypeToolInputStart,
								ID:           toolCallDelta.ID,
								ToolCallName: toolCallDelta.Function.Name,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = toolCall{
								id:        toolCallDelta.ID,
								name:      toolCallDelta.Function.Name,
								arguments: toolCallDelta.Function.Arguments,
							}

							exTc := toolCalls[toolCallDelta.Index]
							if exTc.arguments != "" {
								if !yield(ai.StreamPart{
									Type:  ai.StreamPartTypeToolInputDelta,
									ID:    exTc.id,
									Delta: exTc.arguments,
								}) {
									return
								}
								if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
									if !yield(ai.StreamPart{
										Type: ai.StreamPartTypeToolInputEnd,
										ID:   toolCallDelta.ID,
									}) {
										return
									}

									if !yield(ai.StreamPart{
										Type:          ai.StreamPartTypeToolCall,
										ID:            exTc.id,
										ToolCallName:  exTc.name,
										ToolCallInput: exTc.arguments,
									}) {
										return
									}
									exTc.hasFinished = true
									toolCalls[toolCallDelta.Index] = exTc
								}
							}
							continue
						}
					}
				}
			}

			// Check for annotations in the delta's raw JSON
			for _, choice := range chunk.Choices {
				if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
					for _, annotation := range annotations {
						if annotation.Type == "url_citation" {
							if !yield(ai.StreamPart{
								Type:       ai.StreamPartTypeSource,
								ID:         uuid.NewString(),
								SourceType: ai.SourceTypeURL,
								URL:        annotation.URLCitation.URL,
								Title:      annotation.URLCitation.Title,
							}) {
								return
							}
						}
					}
				}
			}
		}
		err := stream.Err()
		if err != nil && !errors.Is(err, io.EOF) {
			yield(ai.StreamPart{
				Type:  ai.StreamPartTypeError,
				Error: o.handleError(err),
			})
			return
		}

		// finished
		if isActiveText {
			isActiveText = false
			if !yield(ai.StreamPart{
				Type: ai.StreamPartTypeTextEnd,
				ID:   "0",
			}) {
				return
			}
		}

		// Add logprobs if available
		if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
			streamProviderMetadata.Logprobs = acc.Choices[0].Logprobs.Content
		}

		// Handle annotations/citations from accumulated response
		if len(acc.Choices) > 0 {
			for _, annotation := range acc.Choices[0].Message.Annotations {
				if annotation.Type == "url_citation" {
					if !yield(ai.StreamPart{
						Type:       ai.StreamPartTypeSource,
						ID:         acc.ID,
						SourceType: ai.SourceTypeURL,
						URL:        annotation.URLCitation.URL,
						Title:      annotation.URLCitation.Title,
					}) {
						return
					}
				}
			}
		}

		// Guard against an empty accumulator so we do not panic when the
		// stream produced no choices.
		finishReason := ai.FinishReasonUnknown
		if len(acc.Choices) > 0 {
			finishReason = mapOpenAiFinishReason(acc.Choices[0].FinishReason)
		}
		yield(ai.StreamPart{
			Type:         ai.StreamPartTypeFinish,
			Usage:        usage,
			FinishReason: finishReason,
			ProviderMetadata: ai.ProviderMetadata{
				Name: streamProviderMetadata,
			},
		})
	}, nil
}

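// ParseOptions implements ai.Provider.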
func (o *provider) ParseOptions(data map[string]any) (ai.ProviderOptionsData, error) {
	var options ProviderOptions
	if err := ai.ParseOptions(data, &options); err != nil {
		return nil, err
	}
	return &options, nil
}

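// Name implements ai.Provider.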
func (o *provider) Name() string {
	return Name
}

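// mapOpenAiFinishReason maps an OpenAI finish reason string to the
// corresponding ai.FinishReason.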
func mapOpenAiFinishReason(finishReason string) ai.FinishReason {
	switch finishReason {
	case "stop":
		return ai.FinishReasonStop
	case "length":
		return ai.FinishReasonLength
	case "content_filter":
		return ai.FinishReasonContentFilter
	case "function_call", "tool_calls":
		return ai.FinishReasonToolCalls
	default:
		return ai.FinishReasonUnknown
	}
}

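// isReasoningModel reports whether the model only accepts the restricted
// parameter set of reasoning models (o-series and gpt-5 models, excluding
// the gpt-5-chat models).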
func isReasoningModel(modelID string) bool {
	return strings.HasPrefix(modelID, "o") ||
		(strings.HasPrefix(modelID, "gpt-5") && !strings.HasPrefix(modelID, "gpt-5-chat"))
}

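// isSearchPreviewModel reports whether the model is a search preview model.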
func isSearchPreviewModel(modelID string) bool {
	return strings.Contains(modelID, "search-preview")
}

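// supportsFlexProcessing reports whether the model supports the "flex"
// service tier (o3, o4-mini, and gpt-5 models).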
func supportsFlexProcessing(modelID string) bool {
	return strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini") || strings.HasPrefix(modelID, "gpt-5")
}

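// supportsPriorityProcessing reports whether the model supports the
// "priority" service tier.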
func supportsPriorityProcessing(modelID string) bool {
	return strings.HasPrefix(modelID, "gpt-4") ||
		(strings.HasPrefix(modelID, "gpt-5") && !strings.HasPrefix(modelID, "gpt-5-nano")) ||
		strings.HasPrefix(modelID, "o3") ||
		strings.HasPrefix(modelID, "o4-mini")
}

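// toOpenAiTools converts tools and the tool choice into their OpenAI chat
// completion equivalents, warning on tools that cannot be represented.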
func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				continue
			}
			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
				OfFunction: &openai.ChatCompletionFunctionToolParam{
					Function: shared.FunctionDefinitionParam{
						Name:        ft.Name,
						Description: param.NewOpt(ft.Description),
						Parameters:  openai.FunctionParameters(ft.InputSchema),
						Strict:      param.NewOpt(false),
					},
					Type: "function",
				},
			})
			continue
		}

		// TODO: handle provider tool calls
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return openAiTools, openAiToolChoice, warnings
	}

	switch *toolChoice {
	case ai.ToolChoiceAuto:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("auto"),
		}
	case ai.ToolChoiceNone:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("none"),
		}
	default:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
				Type: "function",
				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
					Name: string(*toolChoice),
				},
			},
		}
	}
	return openAiTools, openAiToolChoice, warnings
}

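// toPrompt converts an ai.Prompt into OpenAI chat messages, collecting
// warnings for content that cannot be represented.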
func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
	var messages []openai.ChatCompletionMessageParamUnion
	var warnings []ai.CallWarning
	for _, msg := range prompt {
		switch msg.Role {
		case ai.MessageRoleSystem:
			var systemPromptParts []string
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeText {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt can only have text content",
					})
					continue
				}
				textPart, ok := ai.AsContentType[ai.TextPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt text part does not have the right type",
					})
					continue
				}
				text := textPart.Text
				if strings.TrimSpace(text) != "" {
					systemPromptParts = append(systemPromptParts, textPart.Text)
				}
			}
			if len(systemPromptParts) == 0 {
				warnings = append(warnings, ai.CallWarning{
					Type:    ai.CallWarningTypeOther,
					Message: "system prompt has no text parts",
				})
				continue
			}
			messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
		case ai.MessageRoleUser:
			// A user message with a single text part maps to a plain string message.
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "user message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.UserMessage(textPart.Text))
				continue
			}
			// Text content plus attachments. For now only images, audio, and
			// PDFs are supported.
			// TODO: add the supported media types to the language model so we
			// can use that to validate the data here.
			var content []openai.ChatCompletionContentPartUnionParam
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message text part does not have the right type",
						})
						continue
					}
					content = append(content, openai.ChatCompletionContentPartUnionParam{
						OfText: &openai.ChatCompletionContentPartTextParam{
							Text: textPart.Text,
						},
					})
				case ai.ContentTypeFile:
					filePart, ok := ai.AsContentType[ai.FilePart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message file part does not have the right type",
						})
						continue
					}

					switch {
					case strings.HasPrefix(filePart.MediaType, "image/"):
						// Handle image files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
						imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}

						// Check for provider-specific options like image detail
						if providerOptions, ok := filePart.ProviderOptions[Name]; ok {
							if detail, ok := providerOptions.(*ProviderFileOptions); ok {
								imageURL.Detail = detail.ImageDetail
							}
						}

						imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})

					case filePart.MediaType == "audio/wav":
						// Handle WAV audio files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "wav",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
						// Handle MP3 audio files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "mp3",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "application/pdf":
						// Handle PDF files
						dataStr := string(filePart.Data)

						// Check if data looks like a file ID (starts with "file-")
						if strings.HasPrefix(dataStr, "file-") {
							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									FileID: param.NewOpt(dataStr),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						} else {
							// Handle as base64 data
							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
							data := "data:application/pdf;base64," + base64Encoded

							filename := filePart.Filename
							if filename == "" {
								// Generate default filename based on content index
								filename = fmt.Sprintf("part-%d.pdf", len(content))
							}

							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									Filename: param.NewOpt(filename),
									FileData: param.NewOpt(data),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						}

					default:
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
						})
					}
				}
			}
			messages = append(messages, openai.UserMessage(content))
		case ai.MessageRoleAssistant:
			// An assistant message with a single text part maps to a plain
			// string message.
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "assistant message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.AssistantMessage(textPart.Text))
				continue
			}
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message text part does not have the right type",
						})
						continue
					}
					assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
						OfString: param.NewOpt(textPart.Text),
					}
				case ai.ContentTypeToolCall:
					toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message tool part does not have the right type",
						})
						continue
					}
					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
						openai.ChatCompletionMessageToolCallUnionParam{
							OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
								ID:   toolCallPart.ToolCallID,
								Type: "function",
								Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
									Name:      toolCallPart.ToolName,
									Arguments: toolCallPart.Input,
								},
							},
						})
				}
			}
			messages = append(messages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})
		case ai.MessageRoleTool:
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeToolResult {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message can only have tool result content",
					})
					continue
				}

				toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message result part does not have the right type",
					})
					continue
				}

				switch toolResultPart.Output.GetType() {
				case ai.ToolResultContentTypeText:
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
				case ai.ToolResultContentTypeError:
					// TODO: check if better handling is needed
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
				}
			}
		}
	}
	return messages, warnings
}

// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
	var annotations []openai.ChatCompletionMessageAnnotation

	// Parse the raw JSON to extract annotations
	var deltaData map[string]any
	if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
		return annotations
	}

	// Check if annotations exist in the delta
	if annotationsData, ok := deltaData["annotations"].([]any); ok {
		for _, annotationData := range annotationsData {
			annotationMap, ok := annotationData.(map[string]any)
			if !ok {
				continue
			}
			if annotationType, ok := annotationMap["type"].(string); !ok || annotationType != "url_citation" {
				continue
			}
			urlCitationData, ok := annotationMap["url_citation"].(map[string]any)
			if !ok {
				continue
			}
			// Use checked assertions so malformed payloads cannot panic.
			url, _ := urlCitationData["url"].(string)
			title, _ := urlCitationData["title"].(string)
			annotations = append(annotations, openai.ChatCompletionMessageAnnotation{
				Type: "url_citation",
				URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
					URL:   url,
					Title: title,
				},
			})
		}
	}

	return annotations
}