language_model_hooks.go

package openai

import (
	"fmt"

	"github.com/charmbracelet/fantasy/ai"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/openai/openai-go/v2/shared"
)

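// Hook signatures that let OpenAI-compatible providers customize pieces of
// the chat completion pipeline: request preparation, finish-reason mapping,
// usage extraction, extra content, and stream handling. Defaults for most of
// them are defined below. A custom hook can wrap a default, e.g. (an
// illustrative sketch; how hooks get registered is up to the provider
// configuration elsewhere in this package):
//
//	var mapFinishReason LanguageModelMapFinishReasonFunc = func(choice openai.ChatCompletionChoice) ai.FinishReason {
//		if choice.FinishReason == "tool_calls" {
//			return ai.FinishReasonToolCalls
//		}
//		return defaultMapFinishReason(choice)
//	}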
type (
	LanguageModelPrepareCallFunc            = func(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error)
	LanguageModelMapFinishReasonFunc        = func(choice openai.ChatCompletionChoice) ai.FinishReason
	LanguageModelUsageFunc                  = func(response openai.ChatCompletion) (ai.Usage, ai.ProviderOptionsData)
	LanguageModelExtraContentFunc           = func(choice openai.ChatCompletionChoice) []ai.Content
	LanguageModelStreamExtraFunc            = func(chunk openai.ChatCompletionChunk, yield func(ai.StreamPart) bool, ctx map[string]any) (map[string]any, bool)
	LanguageModelStreamUsageFunc            = func(chunk openai.ChatCompletionChunk, ctx map[string]any, metadata ai.ProviderMetadata) (ai.Usage, ai.ProviderMetadata)
	LanguageModelStreamProviderMetadataFunc = func(choice openai.ChatCompletionChoice, metadata ai.ProviderMetadata) ai.ProviderMetadata
)

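// defaultPrepareLanguageModelCall copies OpenAI-specific provider options
// onto the outgoing request params, returning warnings (rather than errors)
// for settings the target model does not support.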
func defaultPrepareLanguageModelCall(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
	if call.ProviderOptions == nil {
		return nil, nil
	}
	var warnings []ai.CallWarning
	providerOptions := &ProviderOptions{}
	if v, ok := call.ProviderOptions[Name]; ok {
		providerOptions, ok = v.(*ProviderOptions)
		if !ok {
			return nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
		}
	}

	if providerOptions.LogitBias != nil {
		params.LogitBias = providerOptions.LogitBias
	}
	// When both LogProbs and TopLogProbs are set, TopLogProbs takes
	// precedence; drop the boolean flag so the two options don't conflict.
	if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
		providerOptions.LogProbs = nil
	}
	if providerOptions.LogProbs != nil {
		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
	}
	if providerOptions.TopLogProbs != nil {
		params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
	}
	if providerOptions.User != nil {
		params.User = param.NewOpt(*providerOptions.User)
	}
	if providerOptions.ParallelToolCalls != nil {
		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
	}
	if providerOptions.MaxCompletionTokens != nil {
		params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
	}

	if providerOptions.TextVerbosity != nil {
		params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
	}
	if providerOptions.Prediction != nil {
		// Convert map[string]any to ChatCompletionPredictionContentParam
		if content, ok := providerOptions.Prediction["content"]; ok {
			if contentStr, ok := content.(string); ok {
				params.Prediction = openai.ChatCompletionPredictionContentParam{
					Content: openai.ChatCompletionPredictionContentContentUnionParam{
						OfString: param.NewOpt(contentStr),
					},
				}
			}
		}
	}
	if providerOptions.Store != nil {
		params.Store = param.NewOpt(*providerOptions.Store)
	}
	if providerOptions.Metadata != nil {
		// Convert map[string]any to map[string]string
		metadata := make(map[string]string)
		for k, v := range providerOptions.Metadata {
			if str, ok := v.(string); ok {
				metadata[k] = str
			}
		}
		params.Metadata = metadata
	}
	if providerOptions.PromptCacheKey != nil {
		params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
	}
	if providerOptions.SafetyIdentifier != nil {
		params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
	}
	if providerOptions.ServiceTier != nil {
		params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
	}

	if providerOptions.ReasoningEffort != nil {
		switch *providerOptions.ReasoningEffort {
		case ReasoningEffortMinimal:
			params.ReasoningEffort = shared.ReasoningEffortMinimal
		case ReasoningEffortLow:
			params.ReasoningEffort = shared.ReasoningEffortLow
		case ReasoningEffortMedium:
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case ReasoningEffortHigh:
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			return nil, fmt.Errorf("reasoning effort `%s` not supported", *providerOptions.ReasoningEffort)
		}
	}

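	// Reasoning models do not accept logit_bias or logprobs settings; clear
	// them and emit warnings instead of letting the request fail.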
	if isReasoningModel(model.Model()) {
		if providerOptions.LogitBias != nil {
			params.LogitBias = nil
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "LogitBias",
				Message: "LogitBias is not supported for reasoning models",
			})
		}
		if providerOptions.LogProbs != nil {
			params.Logprobs = param.Opt[bool]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "Logprobs",
				Message: "Logprobs is not supported for reasoning models",
			})
		}
		if providerOptions.TopLogProbs != nil {
			params.TopLogprobs = param.Opt[int64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopLogprobs",
				Message: "TopLogprobs is not supported for reasoning models",
			})
		}
	}

	// Handle service tier validation
	if providerOptions.ServiceTier != nil {
		serviceTier := *providerOptions.ServiceTier
		if serviceTier == "flex" && !supportsFlexProcessing(model.Model()) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "ServiceTier",
				Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
			})
		} else if serviceTier == "priority" && !supportsPriorityProcessing(model.Model()) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "ServiceTier",
				Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported",
			})
		}
	}
	return warnings, nil
}

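// defaultMapFinishReason translates OpenAI finish reasons into the
// provider-agnostic ai.FinishReason values, falling back to
// ai.FinishReasonUnknown for anything unrecognized.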
func defaultMapFinishReason(choice openai.ChatCompletionChoice) ai.FinishReason {
	switch choice.FinishReason {
	case "stop":
		return ai.FinishReasonStop
	case "length":
		return ai.FinishReasonLength
	case "content_filter":
		return ai.FinishReasonContentFilter
	case "function_call", "tool_calls":
		return ai.FinishReasonToolCalls
	default:
		return ai.FinishReasonUnknown
	}
}

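// defaultUsage extracts token usage from a completed response and gathers
// logprobs and prediction-token counts into provider metadata.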
func defaultUsage(response openai.ChatCompletion) (ai.Usage, ai.ProviderOptionsData) {
	if len(response.Choices) == 0 {
		return ai.Usage{}, nil
	}
	choice := response.Choices[0]
	completionTokenDetails := response.Usage.CompletionTokensDetails
	promptTokenDetails := response.Usage.PromptTokensDetails

	// Build provider metadata
	providerMetadata := &ProviderMetadata{}
	// Add logprobs if available
	if len(choice.Logprobs.Content) > 0 {
		providerMetadata.Logprobs = choice.Logprobs.Content
	}

	// Add prediction tokens if available
	if completionTokenDetails.AcceptedPredictionTokens > 0 {
		providerMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
	}
	if completionTokenDetails.RejectedPredictionTokens > 0 {
		providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
	}
	return ai.Usage{
		InputTokens:     response.Usage.PromptTokens,
		OutputTokens:    response.Usage.CompletionTokens,
		TotalTokens:     response.Usage.TotalTokens,
		ReasoningTokens: completionTokenDetails.ReasoningTokens,
		CacheReadTokens: promptTokenDetails.CachedTokens,
	}, providerMetadata
}

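// defaultStreamUsage extracts token usage from a stream chunk. It returns
// empty usage until a chunk carries a populated TotalTokens (typically the
// final chunk when stream usage reporting is enabled), and merges any
// provider metadata accumulated earlier in the stream.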
func defaultStreamUsage(chunk openai.ChatCompletionChunk, ctx map[string]any, metadata ai.ProviderMetadata) (ai.Usage, ai.ProviderMetadata) {
	if chunk.Usage.TotalTokens == 0 {
		return ai.Usage{}, nil
	}
	streamProviderMetadata := &ProviderMetadata{}
	if metadata != nil {
		if providerMetadata, ok := metadata[Name]; ok {
			if converted, ok := providerMetadata.(*ProviderMetadata); ok {
				streamProviderMetadata = converted
			}
		}
	}
	// Read token details from the chunk directly: the stream accumulator
	// does not aggregate prompt token details for us.
	completionTokenDetails := chunk.Usage.CompletionTokensDetails
	promptTokenDetails := chunk.Usage.PromptTokensDetails
	usage := ai.Usage{
		InputTokens:     chunk.Usage.PromptTokens,
		OutputTokens:    chunk.Usage.CompletionTokens,
		TotalTokens:     chunk.Usage.TotalTokens,
		ReasoningTokens: completionTokenDetails.ReasoningTokens,
		CacheReadTokens: promptTokenDetails.CachedTokens,
	}

	// Add prediction tokens if available
	if completionTokenDetails.AcceptedPredictionTokens > 0 {
		streamProviderMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
	}
	if completionTokenDetails.RejectedPredictionTokens > 0 {
		streamProviderMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
	}

	return usage, ai.ProviderMetadata{
		Name: streamProviderMetadata,
	}
}

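// defaultStreamProviderMetadataFunc records per-choice logprobs on the
// OpenAI entry of the provider metadata, creating the entry when absent.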
func defaultStreamProviderMetadataFunc(choice openai.ChatCompletionChoice, metadata ai.ProviderMetadata) ai.ProviderMetadata {
	// Guard against a nil map: writing to it below would otherwise panic.
	if metadata == nil {
		metadata = ai.ProviderMetadata{}
	}
	streamProviderMetadata, ok := metadata[Name]
	if !ok {
		streamProviderMetadata = &ProviderMetadata{}
	}
	if converted, ok := streamProviderMetadata.(*ProviderMetadata); ok {
		converted.Logprobs = choice.Logprobs.Content
		metadata[Name] = converted
	}
	return metadata
}