1package openai
2
3import (
4 "fmt"
5
6 "github.com/charmbracelet/fantasy/ai"
7 "github.com/google/uuid"
8 "github.com/openai/openai-go/v2"
9 "github.com/openai/openai-go/v2/packages/param"
10 "github.com/openai/openai-go/v2/shared"
11)
12
13type (
14 LanguageModelGenerateIDFunc = func() string
15 LanguageModelPrepareCallFunc = func(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error)
16 LanguageModelMapFinishReasonFunc = func(choice openai.ChatCompletionChoice) ai.FinishReason
17 LanguageModelUsageFunc = func(choice openai.ChatCompletion) (ai.Usage, ai.ProviderOptionsData)
18 LanguageModelExtraContentFunc = func(choice openai.ChatCompletionChoice) []ai.Content
19 LanguageModelStreamExtraFunc = func(chunk openai.ChatCompletionChunk, yield func(ai.StreamPart) bool, ctx map[string]any) (map[string]any, bool)
20 LanguageModelStreamUsageFunc = func(chunk openai.ChatCompletionChunk, ctx map[string]any, metadata ai.ProviderMetadata) (ai.Usage, ai.ProviderMetadata)
21 LanguageModelStreamProviderMetadataFunc = func(choice openai.ChatCompletionChoice, metadata ai.ProviderMetadata) ai.ProviderMetadata
22)
23
24func DefaultGenerateID() string {
25 return uuid.NewString()
26}
27
28func DefaultPrepareCallFunc(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
29 if call.ProviderOptions == nil {
30 return nil, nil
31 }
32 var warnings []ai.CallWarning
33 providerOptions := &ProviderOptions{}
34 if v, ok := call.ProviderOptions[Name]; ok {
35 providerOptions, ok = v.(*ProviderOptions)
36 if !ok {
37 return nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
38 }
39 }
40
41 if providerOptions.LogitBias != nil {
42 params.LogitBias = providerOptions.LogitBias
43 }
44 if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
45 providerOptions.LogProbs = nil
46 }
47 if providerOptions.LogProbs != nil {
48 params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
49 }
50 if providerOptions.TopLogProbs != nil {
51 params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
52 }
53 if providerOptions.User != nil {
54 params.User = param.NewOpt(*providerOptions.User)
55 }
56 if providerOptions.ParallelToolCalls != nil {
57 params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
58 }
59 if providerOptions.MaxCompletionTokens != nil {
60 params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
61 }
62
63 if providerOptions.TextVerbosity != nil {
64 params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
65 }
66 if providerOptions.Prediction != nil {
67 // Convert map[string]any to ChatCompletionPredictionContentParam
68 if content, ok := providerOptions.Prediction["content"]; ok {
69 if contentStr, ok := content.(string); ok {
70 params.Prediction = openai.ChatCompletionPredictionContentParam{
71 Content: openai.ChatCompletionPredictionContentContentUnionParam{
72 OfString: param.NewOpt(contentStr),
73 },
74 }
75 }
76 }
77 }
78 if providerOptions.Store != nil {
79 params.Store = param.NewOpt(*providerOptions.Store)
80 }
81 if providerOptions.Metadata != nil {
82 // Convert map[string]any to map[string]string
83 metadata := make(map[string]string)
84 for k, v := range providerOptions.Metadata {
85 if str, ok := v.(string); ok {
86 metadata[k] = str
87 }
88 }
89 params.Metadata = metadata
90 }
91 if providerOptions.PromptCacheKey != nil {
92 params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
93 }
94 if providerOptions.SafetyIdentifier != nil {
95 params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
96 }
97 if providerOptions.ServiceTier != nil {
98 params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
99 }
100
101 if providerOptions.ReasoningEffort != nil {
102 switch *providerOptions.ReasoningEffort {
103 case ReasoningEffortMinimal:
104 params.ReasoningEffort = shared.ReasoningEffortMinimal
105 case ReasoningEffortLow:
106 params.ReasoningEffort = shared.ReasoningEffortLow
107 case ReasoningEffortMedium:
108 params.ReasoningEffort = shared.ReasoningEffortMedium
109 case ReasoningEffortHigh:
110 params.ReasoningEffort = shared.ReasoningEffortHigh
111 default:
112 return nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
113 }
114 }
115
116 if isReasoningModel(model.Model()) {
117 if providerOptions.LogitBias != nil {
118 params.LogitBias = nil
119 warnings = append(warnings, ai.CallWarning{
120 Type: ai.CallWarningTypeUnsupportedSetting,
121 Setting: "LogitBias",
122 Message: "LogitBias is not supported for reasoning models",
123 })
124 }
125 if providerOptions.LogProbs != nil {
126 params.Logprobs = param.Opt[bool]{}
127 warnings = append(warnings, ai.CallWarning{
128 Type: ai.CallWarningTypeUnsupportedSetting,
129 Setting: "Logprobs",
130 Message: "Logprobs is not supported for reasoning models",
131 })
132 }
133 if providerOptions.TopLogProbs != nil {
134 params.TopLogprobs = param.Opt[int64]{}
135 warnings = append(warnings, ai.CallWarning{
136 Type: ai.CallWarningTypeUnsupportedSetting,
137 Setting: "TopLogprobs",
138 Message: "TopLogprobs is not supported for reasoning models",
139 })
140 }
141 }
142
143 // Handle service tier validation
144 if providerOptions.ServiceTier != nil {
145 serviceTier := *providerOptions.ServiceTier
146 if serviceTier == "flex" && !supportsFlexProcessing(model.Model()) {
147 params.ServiceTier = ""
148 warnings = append(warnings, ai.CallWarning{
149 Type: ai.CallWarningTypeUnsupportedSetting,
150 Setting: "ServiceTier",
151 Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
152 })
153 } else if serviceTier == "priority" && !supportsPriorityProcessing(model.Model()) {
154 params.ServiceTier = ""
155 warnings = append(warnings, ai.CallWarning{
156 Type: ai.CallWarningTypeUnsupportedSetting,
157 Setting: "ServiceTier",
158 Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported",
159 })
160 }
161 }
162 return warnings, nil
163}
164
165func DefaultMapFinishReasonFunc(choice openai.ChatCompletionChoice) ai.FinishReason {
166 finishReason := choice.FinishReason
167 switch finishReason {
168 case "stop":
169 return ai.FinishReasonStop
170 case "length":
171 return ai.FinishReasonLength
172 case "content_filter":
173 return ai.FinishReasonContentFilter
174 case "function_call", "tool_calls":
175 return ai.FinishReasonToolCalls
176 default:
177 return ai.FinishReasonUnknown
178 }
179}
180
181func DefaultUsageFunc(response openai.ChatCompletion) (ai.Usage, ai.ProviderOptionsData) {
182 completionTokenDetails := response.Usage.CompletionTokensDetails
183 promptTokenDetails := response.Usage.PromptTokensDetails
184
185 // Build provider metadata
186 providerMetadata := &ProviderMetadata{}
187
188 // Add logprobs if available
189 if len(response.Choices) > 0 && len(response.Choices[0].Logprobs.Content) > 0 {
190 providerMetadata.Logprobs = response.Choices[0].Logprobs.Content
191 }
192
193 // Add prediction tokens if available
194 if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
195 if completionTokenDetails.AcceptedPredictionTokens > 0 {
196 providerMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
197 }
198 if completionTokenDetails.RejectedPredictionTokens > 0 {
199 providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
200 }
201 }
202 return ai.Usage{
203 InputTokens: response.Usage.PromptTokens,
204 OutputTokens: response.Usage.CompletionTokens,
205 TotalTokens: response.Usage.TotalTokens,
206 ReasoningTokens: completionTokenDetails.ReasoningTokens,
207 CacheReadTokens: promptTokenDetails.CachedTokens,
208 }, providerMetadata
209}
210
211func DefaultStreamUsageFunc(chunk openai.ChatCompletionChunk, ctx map[string]any, metadata ai.ProviderMetadata) (ai.Usage, ai.ProviderMetadata) {
212 if chunk.Usage.TotalTokens == 0 {
213 return ai.Usage{}, nil
214 }
215 streamProviderMetadata := &ProviderMetadata{}
216 if metadata != nil {
217 if providerMetadata, ok := metadata[Name]; ok {
218 converted, ok := providerMetadata.(*ProviderMetadata)
219 if ok {
220 streamProviderMetadata = converted
221 }
222 }
223 }
224 // we do this here because the acc does not add prompt details
225 completionTokenDetails := chunk.Usage.CompletionTokensDetails
226 promptTokenDetails := chunk.Usage.PromptTokensDetails
227 usage := ai.Usage{
228 InputTokens: chunk.Usage.PromptTokens,
229 OutputTokens: chunk.Usage.CompletionTokens,
230 TotalTokens: chunk.Usage.TotalTokens,
231 ReasoningTokens: completionTokenDetails.ReasoningTokens,
232 CacheReadTokens: promptTokenDetails.CachedTokens,
233 }
234
235 // Add prediction tokens if available
236 if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
237 if completionTokenDetails.AcceptedPredictionTokens > 0 {
238 streamProviderMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
239 }
240 if completionTokenDetails.RejectedPredictionTokens > 0 {
241 streamProviderMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
242 }
243 }
244
245 return usage, ai.ProviderMetadata{
246 Name: streamProviderMetadata,
247 }
248}
249
250func DefaultStreamProviderMetadataFunc(choice openai.ChatCompletionChoice, metadata ai.ProviderMetadata) ai.ProviderMetadata {
251 streamProviderMetadata, ok := metadata[Name]
252 if !ok {
253 streamProviderMetadata = &ProviderMetadata{}
254 }
255 if converted, ok := streamProviderMetadata.(*ProviderMetadata); ok {
256 converted.Logprobs = choice.Logprobs.Content
257 metadata[Name] = converted
258 }
259 return metadata
260}