package openai

import (
	"cmp"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"maps"
	"strings"

	"github.com/charmbracelet/fantasy/ai"
	xjson "github.com/charmbracelet/x/json"
	"github.com/google/uuid"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/option"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/openai/openai-go/v2/shared"
)

const (
	ProviderName = "openai"
	DefaultURL   = "https://api.openai.com/v1"
)

type provider struct {
	options options
}

type options struct {
	baseURL      string
	apiKey       string
	organization string
	project      string
	name         string
	headers      map[string]string
	client       option.HTTPClient
}

type Option = func(*options)

func New(opts ...Option) ai.Provider {
	providerOptions := options{
		headers: map[string]string{},
	}
	for _, o := range opts {
		o(&providerOptions)
	}

	providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
	providerOptions.name = cmp.Or(providerOptions.name, ProviderName)

	if providerOptions.organization != "" {
		providerOptions.headers["OpenAI-Organization"] = providerOptions.organization
	}
	if providerOptions.project != "" {
		providerOptions.headers["OpenAI-Project"] = providerOptions.project
	}

	return &provider{options: providerOptions}
}
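
// Example usage (a minimal sketch; the model ID and environment variable are
// illustrative assumptions, not mandated by this package):
//
//	provider := New(WithAPIKey(os.Getenv("OPENAI_API_KEY")))
//	model, err := provider.LanguageModel("gpt-4o")
//	if err != nil {
//		// handle error
//	}
//	response, err := model.Generate(ctx, ai.Call{Prompt: prompt})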

func WithBaseURL(baseURL string) Option {
	return func(o *options) {
		o.baseURL = baseURL
	}
}

func WithAPIKey(apiKey string) Option {
	return func(o *options) {
		o.apiKey = apiKey
	}
}

func WithOrganization(organization string) Option {
	return func(o *options) {
		o.organization = organization
	}
}

func WithProject(project string) Option {
	return func(o *options) {
		o.project = project
	}
}

func WithName(name string) Option {
	return func(o *options) {
		o.name = name
	}
}

func WithHeaders(headers map[string]string) Option {
	return func(o *options) {
		maps.Copy(o.headers, headers)
	}
}

func WithHTTPClient(client option.HTTPClient) Option {
	return func(o *options) {
		o.client = client
	}
}

// LanguageModel implements ai.Provider.
func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
	openaiClientOptions := []option.RequestOption{}
	if o.options.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
	}
	if o.options.baseURL != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
	}

	for key, value := range o.options.headers {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	if o.options.client != nil {
		openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
	}

	return languageModel{
		modelID:  modelID,
		provider: fmt.Sprintf("%s.chat", o.options.name),
		options:  o.options,
		client:   openai.NewClient(openaiClientOptions...),
	}, nil
}

type languageModel struct {
	provider string
	modelID  string
	client   openai.Client
	options  options
}

// Model implements ai.LanguageModel.
func (o languageModel) Model() string {
	return o.modelID
}

// Provider implements ai.LanguageModel.
func (o languageModel) Provider() string {
	return o.provider
}

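// prepareParams translates an ai.Call into OpenAI chat-completion request
// parameters, collecting warnings for settings the target model does not
// support.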
func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
	params := &openai.ChatCompletionNewParams{}
	messages, warnings := toPrompt(call.Prompt)
	providerOptions := &ProviderOptions{}
	if v, ok := call.ProviderOptions[OptionsKey]; ok {
		providerOptions, ok = v.(*ProviderOptions)
		if !ok {
			return nil, nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
		}
	}
	if call.TopK != nil {
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedSetting,
			Setting: "top_k",
		})
	}
	params.Messages = messages
	params.Model = o.modelID
	if providerOptions.LogitBias != nil {
		params.LogitBias = providerOptions.LogitBias
	}
	if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
		providerOptions.LogProbs = nil
	}
	if providerOptions.LogProbs != nil {
		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
	}
	if providerOptions.TopLogProbs != nil {
		params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
	}
	if providerOptions.User != nil {
		params.User = param.NewOpt(*providerOptions.User)
	}
	if providerOptions.ParallelToolCalls != nil {
		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
	}

	if call.MaxOutputTokens != nil {
		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
	}
	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}
	if call.FrequencyPenalty != nil {
		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
	}
	if call.PresencePenalty != nil {
		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
	}

	if providerOptions.MaxCompletionTokens != nil {
		params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
	}

	if providerOptions.TextVerbosity != nil {
		params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
	}
	if providerOptions.Prediction != nil {
		// Convert map[string]any to ChatCompletionPredictionContentParam
		if content, ok := providerOptions.Prediction["content"]; ok {
			if contentStr, ok := content.(string); ok {
				params.Prediction = openai.ChatCompletionPredictionContentParam{
					Content: openai.ChatCompletionPredictionContentContentUnionParam{
						OfString: param.NewOpt(contentStr),
					},
				}
			}
		}
	}
	if providerOptions.Store != nil {
		params.Store = param.NewOpt(*providerOptions.Store)
	}
	if providerOptions.Metadata != nil {
		// Convert map[string]any to map[string]string
		metadata := make(map[string]string)
		for k, v := range providerOptions.Metadata {
			if str, ok := v.(string); ok {
				metadata[k] = str
			}
		}
		params.Metadata = metadata
	}
	if providerOptions.PromptCacheKey != nil {
		params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
	}
	if providerOptions.SafetyIdentifier != nil {
		params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
	}
	if providerOptions.ServiceTier != nil {
		params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
	}

	if providerOptions.ReasoningEffort != nil {
		switch *providerOptions.ReasoningEffort {
		case ReasoningEffortMinimal:
			params.ReasoningEffort = shared.ReasoningEffortMinimal
		case ReasoningEffortLow:
			params.ReasoningEffort = shared.ReasoningEffortLow
		case ReasoningEffortMedium:
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case ReasoningEffortHigh:
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			return nil, nil, fmt.Errorf("reasoning effort `%s` not supported", *providerOptions.ReasoningEffort)
		}
	}

	if isReasoningModel(o.modelID) {
		// remove unsupported settings for reasoning models
		// see https://platform.openai.com/docs/guides/reasoning#limitations
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for reasoning models",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopP",
				Details: "TopP is not supported for reasoning models",
			})
		}
		if call.FrequencyPenalty != nil {
			params.FrequencyPenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "FrequencyPenalty",
				Details: "FrequencyPenalty is not supported for reasoning models",
			})
		}
		if call.PresencePenalty != nil {
			params.PresencePenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "PresencePenalty",
				Details: "PresencePenalty is not supported for reasoning models",
			})
		}
		if providerOptions.LogitBias != nil {
			params.LogitBias = nil
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "LogitBias",
				Message: "LogitBias is not supported for reasoning models",
			})
		}
		if providerOptions.LogProbs != nil {
			params.Logprobs = param.Opt[bool]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "Logprobs",
				Message: "Logprobs is not supported for reasoning models",
			})
		}
		if providerOptions.TopLogProbs != nil {
			params.TopLogprobs = param.Opt[int64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopLogprobs",
				Message: "TopLogprobs is not supported for reasoning models",
			})
		}

		// reasoning models use max_completion_tokens instead of max_tokens
		if call.MaxOutputTokens != nil {
			if providerOptions.MaxCompletionTokens == nil {
				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
			}
			params.MaxTokens = param.Opt[int64]{}
		}
	}

	// Handle search preview models
	if isSearchPreviewModel(o.modelID) {
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for the search preview models and has been removed.",
			})
		}
	}

	// Handle service tier validation
	if providerOptions.ServiceTier != nil {
		serviceTier := *providerOptions.ServiceTier
		if serviceTier == "flex" && !supportsFlexProcessing(o.modelID) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "ServiceTier",
				Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
			})
		} else if serviceTier == "priority" && !supportsPriorityProcessing(o.modelID) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "ServiceTier",
				Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported",
			})
		}
	}

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}
	return params, warnings, nil
}

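// handleError converts OpenAI SDK errors into ai.APICallError values,
// preserving the status code, response headers, and request/response dumps.
// Other errors are returned unchanged.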
func (o languageModel) handleError(err error) error {
	var apiErr *openai.Error
	if errors.As(err, &apiErr) {
		requestDump := apiErr.DumpRequest(true)
		responseDump := apiErr.DumpResponse(true)
		headers := map[string]string{}
		for k, h := range apiErr.Response.Header {
			v := h[len(h)-1]
			headers[strings.ToLower(k)] = v
		}
		return ai.NewAPICallError(
			apiErr.Message,
			apiErr.Request.URL.String(),
			string(requestDump),
			apiErr.StatusCode,
			headers,
			string(responseDump),
			apiErr,
			false,
		)
	}
	return err
}

// Generate implements ai.LanguageModel.
func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, o.handleError(err)
	}

	if len(response.Choices) == 0 {
		return nil, errors.New("no response generated")
	}
	choice := response.Choices[0]
	content := make([]ai.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
	text := choice.Message.Content
	if text != "" {
		content = append(content, ai.TextContent{
			Text: text,
		})
	}

	for _, tc := range choice.Message.ToolCalls {
		toolCallID := tc.ID
		if toolCallID == "" {
			toolCallID = uuid.NewString()
		}
		content = append(content, ai.ToolCallContent{
			ProviderExecuted: false, // TODO: update when handling other tools
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Handle annotations/citations
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, ai.SourceContent{
				SourceType: ai.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	completionTokenDetails := response.Usage.CompletionTokensDetails
	promptTokenDetails := response.Usage.PromptTokensDetails

	// Build provider metadata
	providerMetadata := &ProviderMetadata{}
	// Add logprobs if available
	if len(choice.Logprobs.Content) > 0 {
		providerMetadata.Logprobs = choice.Logprobs.Content
	}

	// Add prediction tokens if available
	if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
		if completionTokenDetails.AcceptedPredictionTokens > 0 {
			providerMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
		}
		if completionTokenDetails.RejectedPredictionTokens > 0 {
			providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
		}
	}

	return &ai.Response{
		Content: content,
		Usage: ai.Usage{
			InputTokens:     response.Usage.PromptTokens,
			OutputTokens:    response.Usage.CompletionTokens,
			TotalTokens:     response.Usage.TotalTokens,
			ReasoningTokens: completionTokenDetails.ReasoningTokens,
			CacheReadTokens: promptTokenDetails.CachedTokens,
		},
		FinishReason: mapOpenAiFinishReason(choice.FinishReason),
		ProviderMetadata: ai.ProviderMetadata{
			OptionsKey: providerMetadata,
		},
		Warnings: warnings,
	}, nil
}

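// toolCall tracks a tool call whose JSON arguments are streamed incrementally
// across chunks.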
type toolCall struct {
	id          string
	name        string
	arguments   string
	hasFinished bool
}

// Stream implements ai.LanguageModel.
func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}

	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
	isActiveText := false
	toolCalls := make(map[int64]toolCall)

	// Build provider metadata for streaming
	streamProviderMetadata := &ProviderMetadata{}
	acc := openai.ChatCompletionAccumulator{}
	var usage ai.Usage
	return func(yield func(ai.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(ai.StreamPart{
				Type:     ai.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}
		for stream.Next() {
			chunk := stream.Current()
			acc.AddChunk(chunk)
			if chunk.Usage.TotalTokens > 0 {
				// we do this here because the acc does not add prompt details
				completionTokenDetails := chunk.Usage.CompletionTokensDetails
				promptTokenDetails := chunk.Usage.PromptTokensDetails
				usage = ai.Usage{
					InputTokens:     chunk.Usage.PromptTokens,
					OutputTokens:    chunk.Usage.CompletionTokens,
					TotalTokens:     chunk.Usage.TotalTokens,
					ReasoningTokens: completionTokenDetails.ReasoningTokens,
					CacheReadTokens: promptTokenDetails.CachedTokens,
				}

				// Add prediction tokens if available
				if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
					if completionTokenDetails.AcceptedPredictionTokens > 0 {
						streamProviderMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
					}
					if completionTokenDetails.RejectedPredictionTokens > 0 {
						streamProviderMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
					}
				}
			}
			if len(chunk.Choices) == 0 {
				continue
			}
			for _, choice := range chunk.Choices {
				switch {
				case choice.Delta.Content != "":
					if !isActiveText {
						isActiveText = true
						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeTextStart,
							ID:   "0",
						}) {
							return
						}
					}
					if !yield(ai.StreamPart{
						Type:  ai.StreamPartTypeTextDelta,
						ID:    "0",
						Delta: choice.Delta.Content,
					}) {
						return
					}
				case len(choice.Delta.ToolCalls) > 0:
					if isActiveText {
						isActiveText = false
						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeTextEnd,
							ID:   "0",
						}) {
							return
						}
					}

					for _, toolCallDelta := range choice.Delta.ToolCalls {
						if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
							if existingToolCall.hasFinished {
								continue
							}
							if toolCallDelta.Function.Arguments != "" {
								existingToolCall.arguments += toolCallDelta.Function.Arguments
							}
							if !yield(ai.StreamPart{
								Type:  ai.StreamPartTypeToolInputDelta,
								ID:    existingToolCall.id,
								Delta: toolCallDelta.Function.Arguments,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = existingToolCall
							if xjson.IsValid(existingToolCall.arguments) {
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeToolInputEnd,
									ID:   existingToolCall.id,
								}) {
									return
								}

								if !yield(ai.StreamPart{
									Type:          ai.StreamPartTypeToolCall,
									ID:            existingToolCall.id,
									ToolCallName:  existingToolCall.name,
									ToolCallInput: existingToolCall.arguments,
								}) {
									return
								}
								existingToolCall.hasFinished = true
								toolCalls[toolCallDelta.Index] = existingToolCall
							}
						} else {
							// The tool call does not exist yet; validate the first delta.
							var err error
							if toolCallDelta.Type != "function" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
							}
							if toolCallDelta.ID == "" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
							}
							if toolCallDelta.Function.Name == "" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
							}
							if err != nil {
								// Report the validation error itself rather than
								// stream.Err(), which is typically nil while the
								// stream is still open.
								yield(ai.StreamPart{
									Type:  ai.StreamPartTypeError,
									Error: err,
								})
								return
							}

							if !yield(ai.StreamPart{
								Type:         ai.StreamPartTypeToolInputStart,
								ID:           toolCallDelta.ID,
								ToolCallName: toolCallDelta.Function.Name,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = toolCall{
								id:        toolCallDelta.ID,
								name:      toolCallDelta.Function.Name,
								arguments: toolCallDelta.Function.Arguments,
							}

							exTc := toolCalls[toolCallDelta.Index]
							if exTc.arguments != "" {
								if !yield(ai.StreamPart{
									Type:  ai.StreamPartTypeToolInputDelta,
									ID:    exTc.id,
									Delta: exTc.arguments,
								}) {
									return
								}
								if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
									if !yield(ai.StreamPart{
										Type: ai.StreamPartTypeToolInputEnd,
										ID:   toolCallDelta.ID,
									}) {
										return
									}

									if !yield(ai.StreamPart{
										Type:          ai.StreamPartTypeToolCall,
										ID:            exTc.id,
										ToolCallName:  exTc.name,
										ToolCallInput: exTc.arguments,
									}) {
										return
									}
									exTc.hasFinished = true
									toolCalls[toolCallDelta.Index] = exTc
								}
							}
							continue
						}
					}
				}
			}

			// Check for annotations in the delta's raw JSON
			for _, choice := range chunk.Choices {
				if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
					for _, annotation := range annotations {
						if annotation.Type == "url_citation" {
							if !yield(ai.StreamPart{
								Type:       ai.StreamPartTypeSource,
								ID:         uuid.NewString(),
								SourceType: ai.SourceTypeURL,
								URL:        annotation.URLCitation.URL,
								Title:      annotation.URLCitation.Title,
							}) {
								return
							}
						}
					}
				}
			}
		}
		err := stream.Err()
		if err == nil || errors.Is(err, io.EOF) {
			// finished
			if isActiveText {
				isActiveText = false
				if !yield(ai.StreamPart{
					Type: ai.StreamPartTypeTextEnd,
					ID:   "0",
				}) {
					return
				}
			}

			// Add logprobs if available
			if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
				streamProviderMetadata.Logprobs = acc.Choices[0].Logprobs.Content
			}

			// Handle annotations/citations from accumulated response
			if len(acc.Choices) > 0 {
				for _, annotation := range acc.Choices[0].Message.Annotations {
					if annotation.Type == "url_citation" {
						if !yield(ai.StreamPart{
							Type:       ai.StreamPartTypeSource,
							ID:         acc.ID,
							SourceType: ai.SourceTypeURL,
							URL:        annotation.URLCitation.URL,
							Title:      annotation.URLCitation.Title,
						}) {
							return
						}
					}
				}
			}

			// Guard against an empty accumulator (e.g. a stream that only
			// carried a usage chunk).
			finishReason := ai.FinishReasonUnknown
			if len(acc.Choices) > 0 {
				finishReason = mapOpenAiFinishReason(acc.Choices[0].FinishReason)
			}
			yield(ai.StreamPart{
				Type:         ai.StreamPartTypeFinish,
				Usage:        usage,
				FinishReason: finishReason,
				ProviderMetadata: ai.ProviderMetadata{
					OptionsKey: streamProviderMetadata,
				},
			})
			return
		} else {
			yield(ai.StreamPart{
				Type:  ai.StreamPartTypeError,
				Error: o.handleError(err),
			})
			return
		}
	}, nil
}

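// mapOpenAiFinishReason maps an OpenAI finish_reason string onto the
// corresponding ai.FinishReason.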
func mapOpenAiFinishReason(finishReason string) ai.FinishReason {
	switch finishReason {
	case "stop":
		return ai.FinishReasonStop
	case "length":
		return ai.FinishReasonLength
	case "content_filter":
		return ai.FinishReasonContentFilter
	case "function_call", "tool_calls":
		return ai.FinishReasonToolCalls
	default:
		return ai.FinishReasonUnknown
	}
}

func isReasoningModel(modelID string) bool {
	// gpt-5-chat models are conversational and do not accept reasoning
	// settings, so exclude them from the gpt-5 prefix match.
	return strings.HasPrefix(modelID, "o") ||
		(strings.HasPrefix(modelID, "gpt-5") && !strings.HasPrefix(modelID, "gpt-5-chat"))
}

func isSearchPreviewModel(modelID string) bool {
	return strings.Contains(modelID, "search-preview")
}

func supportsFlexProcessing(modelID string) bool {
	return strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini") || strings.HasPrefix(modelID, "gpt-5")
}

func supportsPriorityProcessing(modelID string) bool {
	// gpt-5-nano does not support priority processing (see the ServiceTier
	// warning in prepareParams), so exclude it from the gpt-5 prefix match.
	return strings.HasPrefix(modelID, "gpt-4") ||
		strings.HasPrefix(modelID, "gpt-5-mini") ||
		(strings.HasPrefix(modelID, "gpt-5") && !strings.HasPrefix(modelID, "gpt-5-nano")) ||
		strings.HasPrefix(modelID, "o3") ||
		strings.HasPrefix(modelID, "o4-mini")
}

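// toOpenAiTools converts tool definitions and the tool choice into their
// OpenAI chat-completion equivalents, warning about tools that cannot be
// expressed.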
func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				continue
			}
			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
				OfFunction: &openai.ChatCompletionFunctionToolParam{
					Function: shared.FunctionDefinitionParam{
						Name:        ft.Name,
						Description: param.NewOpt(ft.Description),
						Parameters:  openai.FunctionParameters(ft.InputSchema),
						Strict:      param.NewOpt(false),
					},
					Type: "function",
				},
			})
			continue
		}

		// TODO: handle provider tool calls
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return openAiTools, openAiToolChoice, warnings
	}

	switch *toolChoice {
	case ai.ToolChoiceAuto:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("auto"),
		}
	case ai.ToolChoiceNone:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("none"),
		}
	default:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
				Type: "function",
				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
					Name: string(*toolChoice),
				},
			},
		}
	}
	return openAiTools, openAiToolChoice, warnings
}

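// toPrompt converts an ai.Prompt into OpenAI chat messages, accumulating
// warnings for content that cannot be represented in the Chat Completions
// API.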
func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
	var messages []openai.ChatCompletionMessageParamUnion
	var warnings []ai.CallWarning
	for _, msg := range prompt {
		switch msg.Role {
		case ai.MessageRoleSystem:
			var systemPromptParts []string
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeText {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt can only have text content",
					})
					continue
				}
				textPart, ok := ai.AsContentType[ai.TextPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt text part does not have the right type",
					})
					continue
				}
				text := textPart.Text
				if strings.TrimSpace(text) != "" {
					systemPromptParts = append(systemPromptParts, textPart.Text)
				}
			}
			if len(systemPromptParts) == 0 {
				warnings = append(warnings, ai.CallWarning{
					Type:    ai.CallWarningTypeOther,
					Message: "system prompt has no text parts",
				})
				continue
			}
			messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
		case ai.MessageRoleUser:
			// A simple user message with only text content.
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "user message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.UserMessage(textPart.Text))
				continue
			}
			// Text content plus attachments. For now we support image, audio,
			// and PDF content.
			// TODO: add the supported media types to the language model so we
			// can use that to validate the data here.
			var content []openai.ChatCompletionContentPartUnionParam
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message text part does not have the right type",
						})
						continue
					}
					content = append(content, openai.ChatCompletionContentPartUnionParam{
						OfText: &openai.ChatCompletionContentPartTextParam{
							Text: textPart.Text,
						},
					})
				case ai.ContentTypeFile:
					filePart, ok := ai.AsContentType[ai.FilePart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message file part does not have the right type",
						})
						continue
					}

					switch {
					case strings.HasPrefix(filePart.MediaType, "image/"):
						// Handle image files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
						imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}

						// Check for provider-specific options like image detail
						if providerOptions, ok := filePart.ProviderOptions[OptionsKey]; ok {
							if detail, ok := providerOptions.(*ProviderFileOptions); ok {
								imageURL.Detail = detail.ImageDetail
							}
						}

						imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})

					case filePart.MediaType == "audio/wav":
						// Handle WAV audio files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "wav",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
						// Handle MP3 audio files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "mp3",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "application/pdf":
						// Handle PDF files
						dataStr := string(filePart.Data)

						// Check if data looks like a file ID (starts with "file-")
						if strings.HasPrefix(dataStr, "file-") {
							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									FileID: param.NewOpt(dataStr),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						} else {
							// Handle as base64 data
							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
							data := "data:application/pdf;base64," + base64Encoded

							filename := filePart.Filename
							if filename == "" {
								// Generate default filename based on content index
								filename = fmt.Sprintf("part-%d.pdf", len(content))
							}

							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									Filename: param.NewOpt(filename),
									FileData: param.NewOpt(data),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						}

					default:
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
						})
					}
				}
			}
			messages = append(messages, openai.UserMessage(content))
		case ai.MessageRoleAssistant:
			// A simple assistant message with only text content.
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "assistant message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.AssistantMessage(textPart.Text))
				continue
			}
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message text part does not have the right type",
						})
						continue
					}
					assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
						OfString: param.NewOpt(textPart.Text),
					}
				case ai.ContentTypeToolCall:
					toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message tool part does not have the right type",
						})
						continue
					}
					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
						openai.ChatCompletionMessageToolCallUnionParam{
							OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
								ID:   toolCallPart.ToolCallID,
								Type: "function",
								Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
									Name:      toolCallPart.ToolName,
									Arguments: toolCallPart.Input,
								},
							},
						})
				}
			}
			messages = append(messages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})
		case ai.MessageRoleTool:
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeToolResult {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message can only have tool result content",
					})
					continue
				}

				toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message result part does not have the right type",
					})
					continue
				}

				switch toolResultPart.Output.GetType() {
				case ai.ToolResultContentTypeText:
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
				case ai.ToolResultContentTypeError:
					// TODO: check if better handling is needed
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
				}
			}
		}
	}
	return messages, warnings
}

// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
	var annotations []openai.ChatCompletionMessageAnnotation

	// Parse the raw JSON to extract annotations
	var deltaData map[string]any
	if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
		return annotations
	}

	// Check if annotations exist in the delta
	if annotationsData, ok := deltaData["annotations"].([]any); ok {
		for _, annotationData := range annotationsData {
			if annotationMap, ok := annotationData.(map[string]any); ok {
				if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
					if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
						// Use checked assertions so malformed data cannot panic.
						url, _ := urlCitationData["url"].(string)
						title, _ := urlCitationData["title"].(string)
						annotations = append(annotations, openai.ChatCompletionMessageAnnotation{
							Type: "url_citation",
							URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
								URL:   url,
								Title: title,
							},
						})
					}
				}
			}
		}
	}

	return annotations
}