language_model_hooks.go

package openai

import (
	"fmt"

	"github.com/charmbracelet/fantasy/ai"
	"github.com/google/uuid"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/openai/openai-go/v2/shared"
)

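// Hook functions used by the OpenAI language model to build requests and
// interpret responses. Each has a Default* implementation below, which
// callers can replace with their own.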
type (
	LanguageModelGenerateIDFunc             = func() string
	LanguageModelPrepareCallFunc            = func(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error)
	LanguageModelMapFinishReasonFunc        = func(choice openai.ChatCompletionChoice) ai.FinishReason
	LanguageModelUsageFunc                  = func(response openai.ChatCompletion) (ai.Usage, ai.ProviderOptionsData)
	LanguageModelExtraContentFunc           = func(choice openai.ChatCompletionChoice) []ai.Content
	LanguageModelStreamExtraFunc            = func(chunk openai.ChatCompletionChunk, yield func(ai.StreamPart) bool, ctx map[string]any) (map[string]any, bool)
	LanguageModelStreamUsageFunc            = func(chunk openai.ChatCompletionChunk, ctx map[string]any, metadata ai.ProviderMetadata) (ai.Usage, ai.ProviderMetadata)
	LanguageModelStreamProviderMetadataFunc = func(choice openai.ChatCompletionChoice, metadata ai.ProviderMetadata) ai.ProviderMetadata
)

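// DefaultGenerateID generates response IDs using random UUIDs.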
func DefaultGenerateID() string {
	return uuid.NewString()
}

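// DefaultPrepareCallFunc copies OpenAI-specific provider options from the
// call onto the request params, returning warnings for any settings the
// target model does not support.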
func DefaultPrepareCallFunc(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
	if call.ProviderOptions == nil {
		return nil, nil
	}
	var warnings []ai.CallWarning
	providerOptions := &ProviderOptions{}
	if v, ok := call.ProviderOptions[Name]; ok {
		providerOptions, ok = v.(*ProviderOptions)
		if !ok {
			return nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
		}
	}

	if providerOptions.LogitBias != nil {
		params.LogitBias = providerOptions.LogitBias
	}
	// When both are set, drop the explicit LogProbs flag so that only
	// TopLogprobs is sent on the request.
	if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
		providerOptions.LogProbs = nil
	}
	if providerOptions.LogProbs != nil {
		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
	}
	if providerOptions.TopLogProbs != nil {
		params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
	}
	if providerOptions.User != nil {
		params.User = param.NewOpt(*providerOptions.User)
	}
	if providerOptions.ParallelToolCalls != nil {
		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
	}
	if providerOptions.MaxCompletionTokens != nil {
		params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
	}

	if providerOptions.TextVerbosity != nil {
		params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
	}
	if providerOptions.Prediction != nil {
		// Convert map[string]any to ChatCompletionPredictionContentParam
		if content, ok := providerOptions.Prediction["content"]; ok {
			if contentStr, ok := content.(string); ok {
				params.Prediction = openai.ChatCompletionPredictionContentParam{
					Content: openai.ChatCompletionPredictionContentContentUnionParam{
						OfString: param.NewOpt(contentStr),
					},
				}
			}
		}
	}
	if providerOptions.Store != nil {
		params.Store = param.NewOpt(*providerOptions.Store)
	}
	if providerOptions.Metadata != nil {
		// Convert map[string]any to map[string]string
		metadata := make(map[string]string)
		for k, v := range providerOptions.Metadata {
			if str, ok := v.(string); ok {
				metadata[k] = str
			}
		}
		params.Metadata = metadata
	}
	if providerOptions.PromptCacheKey != nil {
		params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
	}
	if providerOptions.SafetyIdentifier != nil {
		params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
	}
	if providerOptions.ServiceTier != nil {
		params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
	}

	if providerOptions.ReasoningEffort != nil {
		switch *providerOptions.ReasoningEffort {
		case ReasoningEffortMinimal:
			params.ReasoningEffort = shared.ReasoningEffortMinimal
		case ReasoningEffortLow:
			params.ReasoningEffort = shared.ReasoningEffortLow
		case ReasoningEffortMedium:
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case ReasoningEffortHigh:
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			return nil, fmt.Errorf("reasoning effort `%s` is not supported", *providerOptions.ReasoningEffort)
		}
	}

	if isReasoningModel(model.Model()) {
		if providerOptions.LogitBias != nil {
			params.LogitBias = nil
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "LogitBias",
				Message: "LogitBias is not supported for reasoning models",
			})
		}
		if providerOptions.LogProbs != nil {
			params.Logprobs = param.Opt[bool]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "Logprobs",
				Message: "Logprobs is not supported for reasoning models",
			})
		}
		if providerOptions.TopLogProbs != nil {
			params.TopLogprobs = param.Opt[int64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopLogprobs",
				Message: "TopLogprobs is not supported for reasoning models",
			})
		}
	}

	// Validate that the requested service tier is available for this model.
	if providerOptions.ServiceTier != nil {
		serviceTier := *providerOptions.ServiceTier
		if serviceTier == "flex" && !supportsFlexProcessing(model.Model()) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "ServiceTier",
				Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
			})
		} else if serviceTier == "priority" && !supportsPriorityProcessing(model.Model()) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "ServiceTier",
				Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access; gpt-5-nano is not supported",
			})
		}
	}
	return warnings, nil
}

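// DefaultMapFinishReasonFunc maps an OpenAI finish reason onto the
// provider-agnostic ai.FinishReason values.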
func DefaultMapFinishReasonFunc(choice openai.ChatCompletionChoice) ai.FinishReason {
	switch choice.FinishReason {
	case "stop":
		return ai.FinishReasonStop
	case "length":
		return ai.FinishReasonLength
	case "content_filter":
		return ai.FinishReasonContentFilter
	case "function_call", "tool_calls":
		return ai.FinishReasonToolCalls
	default:
		return ai.FinishReasonUnknown
	}
}

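// DefaultUsageFunc extracts token usage and provider metadata (logprobs and
// prediction token counts) from a non-streaming completion response.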
func DefaultUsageFunc(response openai.ChatCompletion) (ai.Usage, ai.ProviderOptionsData) {
	if len(response.Choices) == 0 {
		return ai.Usage{}, nil
	}
	choice := response.Choices[0]
	completionTokenDetails := response.Usage.CompletionTokensDetails
	promptTokenDetails := response.Usage.PromptTokensDetails

	// Build provider metadata
	providerMetadata := &ProviderMetadata{}
	// Add logprobs if available
	if len(choice.Logprobs.Content) > 0 {
		providerMetadata.Logprobs = choice.Logprobs.Content
	}

	// Add prediction tokens if available
	if completionTokenDetails.AcceptedPredictionTokens > 0 {
		providerMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
	}
	if completionTokenDetails.RejectedPredictionTokens > 0 {
		providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
	}
	return ai.Usage{
		InputTokens:     response.Usage.PromptTokens,
		OutputTokens:    response.Usage.CompletionTokens,
		TotalTokens:     response.Usage.TotalTokens,
		ReasoningTokens: completionTokenDetails.ReasoningTokens,
		CacheReadTokens: promptTokenDetails.CachedTokens,
	}, providerMetadata
}

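// DefaultStreamUsageFunc extracts token usage from a streaming chunk that
// carries usage totals, merging prediction token counts into any provider
// metadata accumulated so far.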
func DefaultStreamUsageFunc(chunk openai.ChatCompletionChunk, ctx map[string]any, metadata ai.ProviderMetadata) (ai.Usage, ai.ProviderMetadata) {
	if chunk.Usage.TotalTokens == 0 {
		return ai.Usage{}, nil
	}
	streamProviderMetadata := &ProviderMetadata{}
	if metadata != nil {
		if providerMetadata, ok := metadata[Name]; ok {
			converted, ok := providerMetadata.(*ProviderMetadata)
			if ok {
				streamProviderMetadata = converted
			}
		}
	}
	// Read token details from the chunk itself: the stream accumulator does
	// not carry prompt token details.
	completionTokenDetails := chunk.Usage.CompletionTokensDetails
	promptTokenDetails := chunk.Usage.PromptTokensDetails
	usage := ai.Usage{
		InputTokens:     chunk.Usage.PromptTokens,
		OutputTokens:    chunk.Usage.CompletionTokens,
		TotalTokens:     chunk.Usage.TotalTokens,
		ReasoningTokens: completionTokenDetails.ReasoningTokens,
		CacheReadTokens: promptTokenDetails.CachedTokens,
	}

	// Add prediction tokens if available
	if completionTokenDetails.AcceptedPredictionTokens > 0 {
		streamProviderMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
	}
	if completionTokenDetails.RejectedPredictionTokens > 0 {
		streamProviderMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
	}

	return usage, ai.ProviderMetadata{
		Name: streamProviderMetadata,
	}
}

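// DefaultStreamProviderMetadataFunc copies the choice's logprobs into the
// OpenAI entry of the provider metadata.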
func DefaultStreamProviderMetadataFunc(choice openai.ChatCompletionChoice, metadata ai.ProviderMetadata) ai.ProviderMetadata {
	streamProviderMetadata, ok := metadata[Name]
	if !ok {
		streamProviderMetadata = &ProviderMetadata{}
	}
	if converted, ok := streamProviderMetadata.(*ProviderMetadata); ok {
		converted.Logprobs = choice.Logprobs.Content
		metadata[Name] = converted
	}
	return metadata
}