language_model_hooks.go

package openai

import (
	"fmt"

	"github.com/charmbracelet/fantasy/ai"
	"github.com/google/uuid"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/openai/openai-go/v2/shared"
)

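// Hook functions let callers customize how the OpenAI language model builds
// requests, maps responses, and reports usage. Each hook has a Default*
// implementation below that can be wrapped or replaced.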
type (
	LanguageModelGenerateIDFunc             = func() string
	LanguageModelPrepareCallFunc            = func(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error)
	LanguageModelMapFinishReasonFunc        = func(choice openai.ChatCompletionChoice) ai.FinishReason
	LanguageModelUsageFunc                  = func(response openai.ChatCompletion) (ai.Usage, ai.ProviderOptionsData)
	LanguageModelExtraContentFunc           = func(choice openai.ChatCompletionChoice) []ai.Content
	LanguageModelStreamExtraFunc            = func(chunk openai.ChatCompletionChunk, yield func(ai.StreamPart) bool, ctx map[string]any) (map[string]any, bool)
	LanguageModelStreamUsageFunc            = func(chunk openai.ChatCompletionChunk, ctx map[string]any, metadata ai.ProviderMetadata) (ai.Usage, ai.ProviderMetadata)
	LanguageModelStreamProviderMetadataFunc = func(choice openai.ChatCompletionChoice, metadata ai.ProviderMetadata) ai.ProviderMetadata
)

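// DefaultGenerateID returns a random UUID string. It is the default
// LanguageModelGenerateIDFunc used to assign IDs to generated content.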
func DefaultGenerateID() string {
	return uuid.NewString()
}

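// DefaultPrepareCallFunc copies OpenAI-specific provider options from the
// call onto the request params. It returns warnings for settings the selected
// model does not support and an error if the provider options have the wrong
// type. A minimal sketch of attaching options to a call (assumes the call's
// ProviderOptions map is already initialized):
//
//	user := "user-123"
//	call.ProviderOptions[Name] = &ProviderOptions{User: &user}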
func DefaultPrepareCallFunc(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
	if call.ProviderOptions == nil {
		return nil, nil
	}
	var warnings []ai.CallWarning
	providerOptions := &ProviderOptions{}
	if v, ok := call.ProviderOptions[Name]; ok {
		providerOptions, ok = v.(*ProviderOptions)
		if !ok {
			return nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
		}
	}

	if providerOptions.LogitBias != nil {
		params.LogitBias = providerOptions.LogitBias
	}
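	// TopLogProbs takes precedence: when both options are set, drop the
	// boolean LogProbs flag and send only TopLogprobs.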
	if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
		providerOptions.LogProbs = nil
	}
	if providerOptions.LogProbs != nil {
		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
	}
	if providerOptions.TopLogProbs != nil {
		params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
	}
	if providerOptions.User != nil {
		params.User = param.NewOpt(*providerOptions.User)
	}
	if providerOptions.ParallelToolCalls != nil {
		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
	}
	if providerOptions.MaxCompletionTokens != nil {
		params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
	}

	if providerOptions.TextVerbosity != nil {
		params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
	}
	if providerOptions.Prediction != nil {
		// Convert map[string]any to ChatCompletionPredictionContentParam;
		// only string content is supported.
		if content, ok := providerOptions.Prediction["content"]; ok {
			if contentStr, ok := content.(string); ok {
				params.Prediction = openai.ChatCompletionPredictionContentParam{
					Content: openai.ChatCompletionPredictionContentContentUnionParam{
						OfString: param.NewOpt(contentStr),
					},
				}
			}
		}
	}
	if providerOptions.Store != nil {
		params.Store = param.NewOpt(*providerOptions.Store)
	}
	if providerOptions.Metadata != nil {
		// Convert map[string]any to map[string]string, skipping non-string values.
		metadata := make(map[string]string)
		for k, v := range providerOptions.Metadata {
			if str, ok := v.(string); ok {
				metadata[k] = str
			}
		}
		params.Metadata = metadata
	}
	if providerOptions.PromptCacheKey != nil {
		params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
	}
	if providerOptions.SafetyIdentifier != nil {
		params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
	}
	if providerOptions.ServiceTier != nil {
		params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
	}

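	// Map the provider-level reasoning effort onto the SDK enum; unknown
	// values are rejected rather than passed through.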
	if providerOptions.ReasoningEffort != nil {
		switch *providerOptions.ReasoningEffort {
		case ReasoningEffortMinimal:
			params.ReasoningEffort = shared.ReasoningEffortMinimal
		case ReasoningEffortLow:
			params.ReasoningEffort = shared.ReasoningEffortLow
		case ReasoningEffortMedium:
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case ReasoningEffortHigh:
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			return nil, fmt.Errorf("reasoning effort `%s` not supported", *providerOptions.ReasoningEffort)
		}
	}

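	// Reasoning models reject several sampling settings; clear them and
	// surface warnings instead of failing the call.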
	if isReasoningModel(model.Model()) {
		if providerOptions.LogitBias != nil {
			params.LogitBias = nil
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "LogitBias",
				Message: "LogitBias is not supported for reasoning models",
			})
		}
		if providerOptions.LogProbs != nil {
			params.Logprobs = param.Opt[bool]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "LogProbs",
				Message: "LogProbs is not supported for reasoning models",
			})
		}
		if providerOptions.TopLogProbs != nil {
			params.TopLogprobs = param.Opt[int64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopLogProbs",
				Message: "TopLogProbs is not supported for reasoning models",
			})
		}
	}

	// Validate that the requested service tier is available for this model.
	if providerOptions.ServiceTier != nil {
		serviceTier := *providerOptions.ServiceTier
		if serviceTier == "flex" && !supportsFlexProcessing(model.Model()) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "ServiceTier",
				Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
			})
		} else if serviceTier == "priority" && !supportsPriorityProcessing(model.Model()) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "ServiceTier",
				Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access; gpt-5-nano is not supported",
			})
		}
	}
	return warnings, nil
}

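// DefaultMapFinishReasonFunc maps an OpenAI finish reason ("stop", "length",
// "content_filter", "tool_calls", "function_call") onto the provider-agnostic
// ai.FinishReason values, falling back to ai.FinishReasonUnknown.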
func DefaultMapFinishReasonFunc(choice openai.ChatCompletionChoice) ai.FinishReason {
	switch choice.FinishReason {
	case "stop":
		return ai.FinishReasonStop
	case "length":
		return ai.FinishReasonLength
	case "content_filter":
		return ai.FinishReasonContentFilter
	case "function_call", "tool_calls":
		return ai.FinishReasonToolCalls
	default:
		return ai.FinishReasonUnknown
	}
}

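// DefaultUsageFunc extracts token usage from a completed response and builds
// OpenAI provider metadata carrying logprobs and accepted/rejected prediction
// token counts.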
func DefaultUsageFunc(response openai.ChatCompletion) (ai.Usage, ai.ProviderOptionsData) {
	completionTokenDetails := response.Usage.CompletionTokensDetails
	promptTokenDetails := response.Usage.PromptTokensDetails

	// Build provider metadata.
	providerMetadata := &ProviderMetadata{}

	// Add logprobs if available.
	if len(response.Choices) > 0 && len(response.Choices[0].Logprobs.Content) > 0 {
		providerMetadata.Logprobs = response.Choices[0].Logprobs.Content
	}

	// Add prediction tokens if available.
	if completionTokenDetails.AcceptedPredictionTokens > 0 {
		providerMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
	}
	if completionTokenDetails.RejectedPredictionTokens > 0 {
		providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
	}

	return ai.Usage{
		InputTokens:     response.Usage.PromptTokens,
		OutputTokens:    response.Usage.CompletionTokens,
		TotalTokens:     response.Usage.TotalTokens,
		ReasoningTokens: completionTokenDetails.ReasoningTokens,
		CacheReadTokens: promptTokenDetails.CachedTokens,
	}, providerMetadata
}

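// DefaultStreamUsageFunc extracts token usage from a streaming chunk. OpenAI
// reports usage only on the final chunk, so chunks without a total token
// count yield an empty usage and nil metadata. Previously accumulated OpenAI
// provider metadata is preserved and extended with prediction token counts.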
func DefaultStreamUsageFunc(chunk openai.ChatCompletionChunk, ctx map[string]any, metadata ai.ProviderMetadata) (ai.Usage, ai.ProviderMetadata) {
	if chunk.Usage.TotalTokens == 0 {
		return ai.Usage{}, nil
	}
	streamProviderMetadata := &ProviderMetadata{}
	if metadata != nil {
		if providerMetadata, ok := metadata[Name]; ok {
			if converted, ok := providerMetadata.(*ProviderMetadata); ok {
				streamProviderMetadata = converted
			}
		}
	}
	// Read token details from the chunk directly: the stream accumulator does
	// not aggregate prompt token details.
	completionTokenDetails := chunk.Usage.CompletionTokensDetails
	promptTokenDetails := chunk.Usage.PromptTokensDetails
	usage := ai.Usage{
		InputTokens:     chunk.Usage.PromptTokens,
		OutputTokens:    chunk.Usage.CompletionTokens,
		TotalTokens:     chunk.Usage.TotalTokens,
		ReasoningTokens: completionTokenDetails.ReasoningTokens,
		CacheReadTokens: promptTokenDetails.CachedTokens,
	}

	// Add prediction tokens if available.
	if completionTokenDetails.AcceptedPredictionTokens > 0 {
		streamProviderMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
	}
	if completionTokenDetails.RejectedPredictionTokens > 0 {
		streamProviderMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
	}

	return usage, ai.ProviderMetadata{
		Name: streamProviderMetadata,
	}
}

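// DefaultStreamProviderMetadataFunc records the choice's logprobs on the
// OpenAI entry of the stream's provider metadata, creating the entry when it
// is missing.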
func DefaultStreamProviderMetadataFunc(choice openai.ChatCompletionChoice, metadata ai.ProviderMetadata) ai.ProviderMetadata {
	streamProviderMetadata, ok := metadata[Name]
	if !ok {
		streamProviderMetadata = &ProviderMetadata{}
	}
	if converted, ok := streamProviderMetadata.(*ProviderMetadata); ok {
		converted.Logprobs = choice.Logprobs.Content
		metadata[Name] = converted
	}
	return metadata
}