1package openai
2
3import (
4 "fmt"
5
6 "github.com/charmbracelet/fantasy/ai"
7 "github.com/google/uuid"
8 "github.com/openai/openai-go/v2"
9 "github.com/openai/openai-go/v2/packages/param"
10 "github.com/openai/openai-go/v2/shared"
11)
12
13type (
14 LanguageModelGenerateIDFunc = func() string
15 LanguageModelPrepareCallFunc = func(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error)
16 LanguageModelMapFinishReasonFunc = func(choice openai.ChatCompletionChoice) ai.FinishReason
17 LanguageModelUsageFunc = func(choice openai.ChatCompletion) (ai.Usage, ai.ProviderOptionsData)
18 LanguageModelExtraContentFunc = func(choice openai.ChatCompletionChoice) []ai.Content
19 LanguageModelStreamExtraFunc = func(chunk openai.ChatCompletionChunk, yield func(ai.StreamPart) bool, ctx map[string]any) (map[string]any, bool)
20 LanguageModelStreamUsageFunc = func(chunk openai.ChatCompletionChunk, ctx map[string]any, metadata ai.ProviderMetadata) (ai.Usage, ai.ProviderMetadata)
21 LanguageModelStreamProviderMetadataFunc = func(choice openai.ChatCompletionChoice, metadata ai.ProviderMetadata) ai.ProviderMetadata
22)
23
24func DefaultGenerateID() string {
25 return uuid.NewString()
26}
27
28func DefaultPrepareCallFunc(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
29 if call.ProviderOptions == nil {
30 return nil, nil
31 }
32 var warnings []ai.CallWarning
33 providerOptions := &ProviderOptions{}
34 if v, ok := call.ProviderOptions[Name]; ok {
35 providerOptions, ok = v.(*ProviderOptions)
36 if !ok {
37 return nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
38 }
39 }
40
41 if providerOptions.LogitBias != nil {
42 params.LogitBias = providerOptions.LogitBias
43 }
44 if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
45 providerOptions.LogProbs = nil
46 }
47 if providerOptions.LogProbs != nil {
48 params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
49 }
50 if providerOptions.TopLogProbs != nil {
51 params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
52 }
53 if providerOptions.User != nil {
54 params.User = param.NewOpt(*providerOptions.User)
55 }
56 if providerOptions.ParallelToolCalls != nil {
57 params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
58 }
59 if providerOptions.MaxCompletionTokens != nil {
60 params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
61 }
62
63 if providerOptions.TextVerbosity != nil {
64 params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
65 }
66 if providerOptions.Prediction != nil {
67 // Convert map[string]any to ChatCompletionPredictionContentParam
68 if content, ok := providerOptions.Prediction["content"]; ok {
69 if contentStr, ok := content.(string); ok {
70 params.Prediction = openai.ChatCompletionPredictionContentParam{
71 Content: openai.ChatCompletionPredictionContentContentUnionParam{
72 OfString: param.NewOpt(contentStr),
73 },
74 }
75 }
76 }
77 }
78 if providerOptions.Store != nil {
79 params.Store = param.NewOpt(*providerOptions.Store)
80 }
81 if providerOptions.Metadata != nil {
82 // Convert map[string]any to map[string]string
83 metadata := make(map[string]string)
84 for k, v := range providerOptions.Metadata {
85 if str, ok := v.(string); ok {
86 metadata[k] = str
87 }
88 }
89 params.Metadata = metadata
90 }
91 if providerOptions.PromptCacheKey != nil {
92 params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
93 }
94 if providerOptions.SafetyIdentifier != nil {
95 params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
96 }
97 if providerOptions.ServiceTier != nil {
98 params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
99 }
100
101 if providerOptions.ReasoningEffort != nil {
102 switch *providerOptions.ReasoningEffort {
103 case ReasoningEffortMinimal:
104 params.ReasoningEffort = shared.ReasoningEffortMinimal
105 case ReasoningEffortLow:
106 params.ReasoningEffort = shared.ReasoningEffortLow
107 case ReasoningEffortMedium:
108 params.ReasoningEffort = shared.ReasoningEffortMedium
109 case ReasoningEffortHigh:
110 params.ReasoningEffort = shared.ReasoningEffortHigh
111 default:
112 return nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
113 }
114 }
115
116 if isReasoningModel(model.Model()) {
117 if providerOptions.LogitBias != nil {
118 params.LogitBias = nil
119 warnings = append(warnings, ai.CallWarning{
120 Type: ai.CallWarningTypeUnsupportedSetting,
121 Setting: "LogitBias",
122 Message: "LogitBias is not supported for reasoning models",
123 })
124 }
125 if providerOptions.LogProbs != nil {
126 params.Logprobs = param.Opt[bool]{}
127 warnings = append(warnings, ai.CallWarning{
128 Type: ai.CallWarningTypeUnsupportedSetting,
129 Setting: "Logprobs",
130 Message: "Logprobs is not supported for reasoning models",
131 })
132 }
133 if providerOptions.TopLogProbs != nil {
134 params.TopLogprobs = param.Opt[int64]{}
135 warnings = append(warnings, ai.CallWarning{
136 Type: ai.CallWarningTypeUnsupportedSetting,
137 Setting: "TopLogprobs",
138 Message: "TopLogprobs is not supported for reasoning models",
139 })
140 }
141 }
142
143 // Handle service tier validation
144 if providerOptions.ServiceTier != nil {
145 serviceTier := *providerOptions.ServiceTier
146 if serviceTier == "flex" && !supportsFlexProcessing(model.Model()) {
147 params.ServiceTier = ""
148 warnings = append(warnings, ai.CallWarning{
149 Type: ai.CallWarningTypeUnsupportedSetting,
150 Setting: "ServiceTier",
151 Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
152 })
153 } else if serviceTier == "priority" && !supportsPriorityProcessing(model.Model()) {
154 params.ServiceTier = ""
155 warnings = append(warnings, ai.CallWarning{
156 Type: ai.CallWarningTypeUnsupportedSetting,
157 Setting: "ServiceTier",
158 Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported",
159 })
160 }
161 }
162 return warnings, nil
163}
164
165func DefaultMapFinishReasonFunc(choice openai.ChatCompletionChoice) ai.FinishReason {
166 finishReason := choice.FinishReason
167 switch finishReason {
168 case "stop":
169 return ai.FinishReasonStop
170 case "length":
171 return ai.FinishReasonLength
172 case "content_filter":
173 return ai.FinishReasonContentFilter
174 case "function_call", "tool_calls":
175 return ai.FinishReasonToolCalls
176 default:
177 return ai.FinishReasonUnknown
178 }
179}
180
181func DefaultUsageFunc(response openai.ChatCompletion) (ai.Usage, ai.ProviderOptionsData) {
182 if len(response.Choices) == 0 {
183 return ai.Usage{}, nil
184 }
185 choice := response.Choices[0]
186 completionTokenDetails := response.Usage.CompletionTokensDetails
187 promptTokenDetails := response.Usage.PromptTokensDetails
188
189 // Build provider metadata
190 providerMetadata := &ProviderMetadata{}
191 // Add logprobs if available
192 if len(choice.Logprobs.Content) > 0 {
193 providerMetadata.Logprobs = choice.Logprobs.Content
194 }
195
196 // Add prediction tokens if available
197 if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
198 if completionTokenDetails.AcceptedPredictionTokens > 0 {
199 providerMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
200 }
201 if completionTokenDetails.RejectedPredictionTokens > 0 {
202 providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
203 }
204 }
205 return ai.Usage{
206 InputTokens: response.Usage.PromptTokens,
207 OutputTokens: response.Usage.CompletionTokens,
208 TotalTokens: response.Usage.TotalTokens,
209 ReasoningTokens: completionTokenDetails.ReasoningTokens,
210 CacheReadTokens: promptTokenDetails.CachedTokens,
211 }, providerMetadata
212}
213
214func DefaultStreamUsageFunc(chunk openai.ChatCompletionChunk, ctx map[string]any, metadata ai.ProviderMetadata) (ai.Usage, ai.ProviderMetadata) {
215 if chunk.Usage.TotalTokens == 0 {
216 return ai.Usage{}, nil
217 }
218 streamProviderMetadata := &ProviderMetadata{}
219 if metadata != nil {
220 if providerMetadata, ok := metadata[Name]; ok {
221 converted, ok := providerMetadata.(*ProviderMetadata)
222 if ok {
223 streamProviderMetadata = converted
224 }
225 }
226 }
227 // we do this here because the acc does not add prompt details
228 completionTokenDetails := chunk.Usage.CompletionTokensDetails
229 promptTokenDetails := chunk.Usage.PromptTokensDetails
230 usage := ai.Usage{
231 InputTokens: chunk.Usage.PromptTokens,
232 OutputTokens: chunk.Usage.CompletionTokens,
233 TotalTokens: chunk.Usage.TotalTokens,
234 ReasoningTokens: completionTokenDetails.ReasoningTokens,
235 CacheReadTokens: promptTokenDetails.CachedTokens,
236 }
237
238 // Add prediction tokens if available
239 if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
240 if completionTokenDetails.AcceptedPredictionTokens > 0 {
241 streamProviderMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
242 }
243 if completionTokenDetails.RejectedPredictionTokens > 0 {
244 streamProviderMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
245 }
246 }
247
248 return usage, ai.ProviderMetadata{
249 Name: streamProviderMetadata,
250 }
251}
252
253func DefaultStreamProviderMetadataFunc(choice openai.ChatCompletionChoice, metadata ai.ProviderMetadata) ai.ProviderMetadata {
254 streamProviderMetadata, ok := metadata[Name]
255 if !ok {
256 streamProviderMetadata = &ProviderMetadata{}
257 }
258 if converted, ok := streamProviderMetadata.(*ProviderMetadata); ok {
259 converted.Logprobs = choice.Logprobs.Content
260 metadata[Name] = converted
261 }
262 return metadata
263}