package openai

import (
	"fmt"

	"github.com/charmbracelet/fantasy/ai"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/openai/openai-go/v2/shared"
)

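// Hook types for customizing how the chat language model prepares requests
// and interprets responses. The Default* functions below provide the
// standard OpenAI behavior for each hook.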
type (
	LanguageModelPrepareCallFunc            = func(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error)
	LanguageModelMapFinishReasonFunc        = func(finishReason string) ai.FinishReason
	LanguageModelUsageFunc                  = func(response openai.ChatCompletion) (ai.Usage, ai.ProviderOptionsData)
	LanguageModelExtraContentFunc           = func(choice openai.ChatCompletionChoice) []ai.Content
	LanguageModelStreamExtraFunc            = func(chunk openai.ChatCompletionChunk, yield func(ai.StreamPart) bool, ctx map[string]any) (map[string]any, bool)
	LanguageModelStreamUsageFunc            = func(chunk openai.ChatCompletionChunk, ctx map[string]any, metadata ai.ProviderMetadata) (ai.Usage, ai.ProviderMetadata)
	LanguageModelStreamProviderMetadataFunc = func(choice openai.ChatCompletionChoice, metadata ai.ProviderMetadata) ai.ProviderMetadata
)

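// DefaultPrepareCallFunc applies OpenAI provider options from the call to
// the request parameters. It returns warnings for settings the resolved
// model does not support and an error when an option value is invalid.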
func DefaultPrepareCallFunc(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
	if call.ProviderOptions == nil {
		return nil, nil
	}
	var warnings []ai.CallWarning
	providerOptions := &ProviderOptions{}
	if v, ok := call.ProviderOptions[Name]; ok {
		providerOptions, ok = v.(*ProviderOptions)
		if !ok {
			return nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
		}
	}

	if providerOptions.LogitBias != nil {
		params.LogitBias = providerOptions.LogitBias
	}
	// TopLogProbs takes precedence: when both are set, drop the boolean
	// LogProbs flag and send only TopLogprobs.
	if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
		providerOptions.LogProbs = nil
	}
	if providerOptions.LogProbs != nil {
		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
	}
	if providerOptions.TopLogProbs != nil {
		params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
	}
	if providerOptions.User != nil {
		params.User = param.NewOpt(*providerOptions.User)
	}
	if providerOptions.ParallelToolCalls != nil {
		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
	}
	if providerOptions.MaxCompletionTokens != nil {
		params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
	}

	if providerOptions.TextVerbosity != nil {
		params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
	}
	if providerOptions.Prediction != nil {
		// Convert the prediction map into a ChatCompletionPredictionContentParam;
		// only string content is supported, other value types are dropped.
		if content, ok := providerOptions.Prediction["content"]; ok {
			if contentStr, ok := content.(string); ok {
				params.Prediction = openai.ChatCompletionPredictionContentParam{
					Content: openai.ChatCompletionPredictionContentContentUnionParam{
						OfString: param.NewOpt(contentStr),
					},
				}
			}
		}
	}
	if providerOptions.Store != nil {
		params.Store = param.NewOpt(*providerOptions.Store)
	}
	if providerOptions.Metadata != nil {
		// Convert map[string]any to map[string]string, keeping only string values.
		metadata := make(map[string]string)
		for k, v := range providerOptions.Metadata {
			if str, ok := v.(string); ok {
				metadata[k] = str
			}
		}
		params.Metadata = metadata
	}
	if providerOptions.PromptCacheKey != nil {
		params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
	}
	if providerOptions.SafetyIdentifier != nil {
		params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
	}
	if providerOptions.ServiceTier != nil {
		params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
	}

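	// Map the provider-level reasoning effort onto the SDK's shared enum;
	// unknown values are rejected.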
	if providerOptions.ReasoningEffort != nil {
		switch *providerOptions.ReasoningEffort {
		case ReasoningEffortMinimal:
			params.ReasoningEffort = shared.ReasoningEffortMinimal
		case ReasoningEffortLow:
			params.ReasoningEffort = shared.ReasoningEffortLow
		case ReasoningEffortMedium:
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case ReasoningEffortHigh:
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			return nil, fmt.Errorf("unsupported reasoning effort: %q", *providerOptions.ReasoningEffort)
		}
	}

	if isReasoningModel(model.Model()) {
		if providerOptions.LogitBias != nil {
			params.LogitBias = nil
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "LogitBias",
				Message: "LogitBias is not supported for reasoning models",
			})
		}
		if providerOptions.LogProbs != nil {
			params.Logprobs = param.Opt[bool]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "Logprobs",
				Message: "Logprobs is not supported for reasoning models",
			})
		}
		if providerOptions.TopLogProbs != nil {
			params.TopLogprobs = param.Opt[int64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopLogprobs",
				Message: "TopLogprobs is not supported for reasoning models",
			})
		}
	}

	// Validate service tier availability for the resolved model.
	if providerOptions.ServiceTier != nil {
		serviceTier := *providerOptions.ServiceTier
		if serviceTier == "flex" && !supportsFlexProcessing(model.Model()) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "ServiceTier",
				Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
			})
		} else if serviceTier == "priority" && !supportsPriorityProcessing(model.Model()) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "ServiceTier",
				Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access; gpt-5-nano is not supported",
			})
		}
	}
	return warnings, nil
}

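// DefaultMapFinishReasonFunc translates an OpenAI finish reason string into
// the provider-agnostic ai.FinishReason.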
func DefaultMapFinishReasonFunc(finishReason string) ai.FinishReason {
	switch finishReason {
	case "stop":
		return ai.FinishReasonStop
	case "length":
		return ai.FinishReasonLength
	case "content_filter":
		return ai.FinishReasonContentFilter
	case "function_call", "tool_calls":
		return ai.FinishReasonToolCalls
	default:
		return ai.FinishReasonUnknown
	}
}

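// DefaultUsageFunc extracts token usage from a completed chat response and
// builds provider metadata carrying logprobs and prediction token counts
// when present.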
func DefaultUsageFunc(response openai.ChatCompletion) (ai.Usage, ai.ProviderOptionsData) {
	completionTokenDetails := response.Usage.CompletionTokensDetails
	promptTokenDetails := response.Usage.PromptTokensDetails

	// Build provider metadata.
	providerMetadata := &ProviderMetadata{}

	// Add logprobs from the first choice if available.
	if len(response.Choices) > 0 && len(response.Choices[0].Logprobs.Content) > 0 {
		providerMetadata.Logprobs = response.Choices[0].Logprobs.Content
	}
	// Add prediction tokens if available; zero counts are omitted.
	if completionTokenDetails.AcceptedPredictionTokens > 0 {
		providerMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
	}
	if completionTokenDetails.RejectedPredictionTokens > 0 {
		providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
	}
	return ai.Usage{
		InputTokens:     response.Usage.PromptTokens,
		OutputTokens:    response.Usage.CompletionTokens,
		TotalTokens:     response.Usage.TotalTokens,
		ReasoningTokens: completionTokenDetails.ReasoningTokens,
		CacheReadTokens: promptTokenDetails.CachedTokens,
	}, providerMetadata
}

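// DefaultStreamUsageFunc reads token usage from the stream chunk that
// carries usage totals, folding prediction token counts into the provider
// metadata. Chunks without usage return a zero ai.Usage and nil metadata.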
func DefaultStreamUsageFunc(chunk openai.ChatCompletionChunk, ctx map[string]any, metadata ai.ProviderMetadata) (ai.Usage, ai.ProviderMetadata) {
	if chunk.Usage.TotalTokens == 0 {
		return ai.Usage{}, nil
	}
	streamProviderMetadata := &ProviderMetadata{}
	if metadata != nil {
		if providerMetadata, ok := metadata[Name]; ok {
			if converted, ok := providerMetadata.(*ProviderMetadata); ok {
				streamProviderMetadata = converted
			}
		}
	}
	// Read token details from the chunk directly; the stream accumulator
	// does not aggregate prompt token details.
	completionTokenDetails := chunk.Usage.CompletionTokensDetails
	promptTokenDetails := chunk.Usage.PromptTokensDetails
	usage := ai.Usage{
		InputTokens:     chunk.Usage.PromptTokens,
		OutputTokens:    chunk.Usage.CompletionTokens,
		TotalTokens:     chunk.Usage.TotalTokens,
		ReasoningTokens: completionTokenDetails.ReasoningTokens,
		CacheReadTokens: promptTokenDetails.CachedTokens,
	}

	// Add prediction tokens if available; zero counts are omitted.
	if completionTokenDetails.AcceptedPredictionTokens > 0 {
		streamProviderMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
	}
	if completionTokenDetails.RejectedPredictionTokens > 0 {
		streamProviderMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
	}

	return usage, ai.ProviderMetadata{
		Name: streamProviderMetadata,
	}
}

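// DefaultStreamProviderMetadataFunc records the current choice's logprobs
// on the OpenAI provider metadata while streaming.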
func DefaultStreamProviderMetadataFunc(choice openai.ChatCompletionChoice, metadata ai.ProviderMetadata) ai.ProviderMetadata {
	// Guard against a nil metadata map so the write below cannot panic.
	if metadata == nil {
		metadata = ai.ProviderMetadata{}
	}
	streamProviderMetadata, ok := metadata[Name]
	if !ok {
		streamProviderMetadata = &ProviderMetadata{}
	}
	if converted, ok := streamProviderMetadata.(*ProviderMetadata); ok {
		converted.Logprobs = choice.Logprobs.Content
		metadata[Name] = converted
	}
	return metadata
}