1package openai
2
3import (
4 "fmt"
5
6 "github.com/charmbracelet/fantasy/ai"
7 "github.com/openai/openai-go/v2"
8 "github.com/openai/openai-go/v2/packages/param"
9 "github.com/openai/openai-go/v2/shared"
10)
11
12type PrepareLanguageModelCallFunc = func(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error)
13
14func defaultPrepareLanguageModelCall(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
15 if call.ProviderOptions == nil {
16 return nil, nil
17 }
18 var warnings []ai.CallWarning
19 providerOptions := &ProviderOptions{}
20 if v, ok := call.ProviderOptions[Name]; ok {
21 providerOptions, ok = v.(*ProviderOptions)
22 if !ok {
23 return nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
24 }
25 }
26
27 if providerOptions.LogitBias != nil {
28 params.LogitBias = providerOptions.LogitBias
29 }
30 if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
31 providerOptions.LogProbs = nil
32 }
33 if providerOptions.LogProbs != nil {
34 params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
35 }
36 if providerOptions.TopLogProbs != nil {
37 params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
38 }
39 if providerOptions.User != nil {
40 params.User = param.NewOpt(*providerOptions.User)
41 }
42 if providerOptions.ParallelToolCalls != nil {
43 params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
44 }
45 if providerOptions.MaxCompletionTokens != nil {
46 params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
47 }
48
49 if providerOptions.TextVerbosity != nil {
50 params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
51 }
52 if providerOptions.Prediction != nil {
53 // Convert map[string]any to ChatCompletionPredictionContentParam
54 if content, ok := providerOptions.Prediction["content"]; ok {
55 if contentStr, ok := content.(string); ok {
56 params.Prediction = openai.ChatCompletionPredictionContentParam{
57 Content: openai.ChatCompletionPredictionContentContentUnionParam{
58 OfString: param.NewOpt(contentStr),
59 },
60 }
61 }
62 }
63 }
64 if providerOptions.Store != nil {
65 params.Store = param.NewOpt(*providerOptions.Store)
66 }
67 if providerOptions.Metadata != nil {
68 // Convert map[string]any to map[string]string
69 metadata := make(map[string]string)
70 for k, v := range providerOptions.Metadata {
71 if str, ok := v.(string); ok {
72 metadata[k] = str
73 }
74 }
75 params.Metadata = metadata
76 }
77 if providerOptions.PromptCacheKey != nil {
78 params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
79 }
80 if providerOptions.SafetyIdentifier != nil {
81 params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
82 }
83 if providerOptions.ServiceTier != nil {
84 params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
85 }
86
87 if providerOptions.ReasoningEffort != nil {
88 switch *providerOptions.ReasoningEffort {
89 case ReasoningEffortMinimal:
90 params.ReasoningEffort = shared.ReasoningEffortMinimal
91 case ReasoningEffortLow:
92 params.ReasoningEffort = shared.ReasoningEffortLow
93 case ReasoningEffortMedium:
94 params.ReasoningEffort = shared.ReasoningEffortMedium
95 case ReasoningEffortHigh:
96 params.ReasoningEffort = shared.ReasoningEffortHigh
97 default:
98 return nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
99 }
100 }
101
102 if isReasoningModel(model.Model()) {
103 if providerOptions.LogitBias != nil {
104 params.LogitBias = nil
105 warnings = append(warnings, ai.CallWarning{
106 Type: ai.CallWarningTypeUnsupportedSetting,
107 Setting: "LogitBias",
108 Message: "LogitBias is not supported for reasoning models",
109 })
110 }
111 if providerOptions.LogProbs != nil {
112 params.Logprobs = param.Opt[bool]{}
113 warnings = append(warnings, ai.CallWarning{
114 Type: ai.CallWarningTypeUnsupportedSetting,
115 Setting: "Logprobs",
116 Message: "Logprobs is not supported for reasoning models",
117 })
118 }
119 if providerOptions.TopLogProbs != nil {
120 params.TopLogprobs = param.Opt[int64]{}
121 warnings = append(warnings, ai.CallWarning{
122 Type: ai.CallWarningTypeUnsupportedSetting,
123 Setting: "TopLogprobs",
124 Message: "TopLogprobs is not supported for reasoning models",
125 })
126 }
127 }
128
129 // Handle service tier validation
130 if providerOptions.ServiceTier != nil {
131 serviceTier := *providerOptions.ServiceTier
132 if serviceTier == "flex" && !supportsFlexProcessing(model.Model()) {
133 params.ServiceTier = ""
134 warnings = append(warnings, ai.CallWarning{
135 Type: ai.CallWarningTypeUnsupportedSetting,
136 Setting: "ServiceTier",
137 Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
138 })
139 } else if serviceTier == "priority" && !supportsPriorityProcessing(model.Model()) {
140 params.ServiceTier = ""
141 warnings = append(warnings, ai.CallWarning{
142 Type: ai.CallWarningTypeUnsupportedSetting,
143 Setting: "ServiceTier",
144 Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported",
145 })
146 }
147 }
148 return warnings, nil
149}