language_model_hooks.go

  1package openai
  2
  3import (
  4	"fmt"
  5
  6	"charm.land/fantasy"
  7	"github.com/openai/openai-go/v2"
  8	"github.com/openai/openai-go/v2/packages/param"
  9	"github.com/openai/openai-go/v2/shared"
 10)
 11
// LanguageModelPrepareCallFunc is a function that prepares the call for the language model.
//
// It receives the target model, the chat completion request parameters (as a
// pointer, so an implementation can modify them in place) and the call being
// made. It returns the warnings produced while preparing the request, or an
// error. DefaultPrepareCallFunc is the default implementation.
type LanguageModelPrepareCallFunc = func(model fantasy.LanguageModel, params *openai.ChatCompletionNewParams, call fantasy.Call) ([]fantasy.CallWarning, error)
 14
// LanguageModelMapFinishReasonFunc is a function that maps the finish reason for the language model.
//
// It receives the finish reason as a plain string and returns the
// corresponding fantasy.FinishReason. DefaultMapFinishReasonFunc is the
// default implementation.
type LanguageModelMapFinishReasonFunc = func(finishReason string) fantasy.FinishReason
 17
 18// LanguageModelUsageFunc is a function that calculates usage for the language model.
 19type LanguageModelUsageFunc = func(choice openai.ChatCompletion) (fantasy.Usage, fantasy.ProviderOptionsData)
 20
// LanguageModelExtraContentFunc is a function that adds extra content for the language model.
//
// It receives one choice of a chat completion and returns the additional
// content items derived from it.
type LanguageModelExtraContentFunc = func(choice openai.ChatCompletionChoice) []fantasy.Content
 23
// LanguageModelStreamExtraFunc is a function that handles stream extra functionality for the language model.
//
// It receives a streamed chunk, the yield callback used to emit stream parts,
// and a free-form context map, and returns a context map and a bool.
// NOTE(review): presumably the returned map is the context handed to the next
// invocation and the bool tells the caller whether to keep streaming — confirm
// against the caller.
type LanguageModelStreamExtraFunc = func(chunk openai.ChatCompletionChunk, yield func(fantasy.StreamPart) bool, ctx map[string]any) (map[string]any, bool)
 26
// LanguageModelStreamUsageFunc is a function that calculates stream usage for the language model.
//
// It receives a streamed chunk, a free-form context map and the provider
// metadata collected so far, and returns the token usage together with the
// provider metadata to use from then on. DefaultStreamUsageFunc is the default
// implementation.
type LanguageModelStreamUsageFunc = func(chunk openai.ChatCompletionChunk, ctx map[string]any, metadata fantasy.ProviderMetadata) (fantasy.Usage, fantasy.ProviderMetadata)
 29
// LanguageModelStreamProviderMetadataFunc is a function that handles stream provider metadata for the language model.
//
// It receives a completion choice and the provider metadata collected so far,
// and returns the provider metadata to use from then on.
// DefaultStreamProviderMetadataFunc is the default implementation.
type LanguageModelStreamProviderMetadataFunc = func(choice openai.ChatCompletionChoice, metadata fantasy.ProviderMetadata) fantasy.ProviderMetadata
 32
 33// DefaultPrepareCallFunc is the default implementation for preparing a call to the language model.
 34func DefaultPrepareCallFunc(model fantasy.LanguageModel, params *openai.ChatCompletionNewParams, call fantasy.Call) ([]fantasy.CallWarning, error) {
 35	if call.ProviderOptions == nil {
 36		return nil, nil
 37	}
 38	var warnings []fantasy.CallWarning
 39	providerOptions := &ProviderOptions{}
 40	if v, ok := call.ProviderOptions[Name]; ok {
 41		providerOptions, ok = v.(*ProviderOptions)
 42		if !ok {
 43			return nil, fantasy.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
 44		}
 45	}
 46
 47	if providerOptions.LogitBias != nil {
 48		params.LogitBias = providerOptions.LogitBias
 49	}
 50	if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
 51		providerOptions.LogProbs = nil
 52	}
 53	if providerOptions.LogProbs != nil {
 54		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
 55	}
 56	if providerOptions.TopLogProbs != nil {
 57		params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
 58	}
 59	if providerOptions.User != nil {
 60		params.User = param.NewOpt(*providerOptions.User)
 61	}
 62	if providerOptions.ParallelToolCalls != nil {
 63		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
 64	}
 65	if providerOptions.MaxCompletionTokens != nil {
 66		params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
 67	}
 68
 69	if providerOptions.TextVerbosity != nil {
 70		params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
 71	}
 72	if providerOptions.Prediction != nil {
 73		// Convert map[string]any to ChatCompletionPredictionContentParam
 74		if content, ok := providerOptions.Prediction["content"]; ok {
 75			if contentStr, ok := content.(string); ok {
 76				params.Prediction = openai.ChatCompletionPredictionContentParam{
 77					Content: openai.ChatCompletionPredictionContentContentUnionParam{
 78						OfString: param.NewOpt(contentStr),
 79					},
 80				}
 81			}
 82		}
 83	}
 84	if providerOptions.Store != nil {
 85		params.Store = param.NewOpt(*providerOptions.Store)
 86	}
 87	if providerOptions.Metadata != nil {
 88		// Convert map[string]any to map[string]string
 89		metadata := make(map[string]string)
 90		for k, v := range providerOptions.Metadata {
 91			if str, ok := v.(string); ok {
 92				metadata[k] = str
 93			}
 94		}
 95		params.Metadata = metadata
 96	}
 97	if providerOptions.PromptCacheKey != nil {
 98		params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
 99	}
100	if providerOptions.SafetyIdentifier != nil {
101		params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
102	}
103	if providerOptions.ServiceTier != nil {
104		params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
105	}
106
107	if providerOptions.ReasoningEffort != nil {
108		switch *providerOptions.ReasoningEffort {
109		case ReasoningEffortMinimal:
110			params.ReasoningEffort = shared.ReasoningEffortMinimal
111		case ReasoningEffortLow:
112			params.ReasoningEffort = shared.ReasoningEffortLow
113		case ReasoningEffortMedium:
114			params.ReasoningEffort = shared.ReasoningEffortMedium
115		case ReasoningEffortHigh:
116			params.ReasoningEffort = shared.ReasoningEffortHigh
117		default:
118			return nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
119		}
120	}
121
122	if isReasoningModel(model.Model()) {
123		if providerOptions.LogitBias != nil {
124			params.LogitBias = nil
125			warnings = append(warnings, fantasy.CallWarning{
126				Type:    fantasy.CallWarningTypeUnsupportedSetting,
127				Setting: "LogitBias",
128				Message: "LogitBias is not supported for reasoning models",
129			})
130		}
131		if providerOptions.LogProbs != nil {
132			params.Logprobs = param.Opt[bool]{}
133			warnings = append(warnings, fantasy.CallWarning{
134				Type:    fantasy.CallWarningTypeUnsupportedSetting,
135				Setting: "Logprobs",
136				Message: "Logprobs is not supported for reasoning models",
137			})
138		}
139		if providerOptions.TopLogProbs != nil {
140			params.TopLogprobs = param.Opt[int64]{}
141			warnings = append(warnings, fantasy.CallWarning{
142				Type:    fantasy.CallWarningTypeUnsupportedSetting,
143				Setting: "TopLogprobs",
144				Message: "TopLogprobs is not supported for reasoning models",
145			})
146		}
147	}
148
149	// Handle service tier validation
150	if providerOptions.ServiceTier != nil {
151		serviceTier := *providerOptions.ServiceTier
152		if serviceTier == "flex" && !supportsFlexProcessing(model.Model()) {
153			params.ServiceTier = ""
154			warnings = append(warnings, fantasy.CallWarning{
155				Type:    fantasy.CallWarningTypeUnsupportedSetting,
156				Setting: "ServiceTier",
157				Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
158			})
159		} else if serviceTier == "priority" && !supportsPriorityProcessing(model.Model()) {
160			params.ServiceTier = ""
161			warnings = append(warnings, fantasy.CallWarning{
162				Type:    fantasy.CallWarningTypeUnsupportedSetting,
163				Setting: "ServiceTier",
164				Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported",
165			})
166		}
167	}
168	return warnings, nil
169}
170
171// DefaultMapFinishReasonFunc is the default implementation for mapping finish reasons.
172func DefaultMapFinishReasonFunc(finishReason string) fantasy.FinishReason {
173	switch finishReason {
174	case "stop":
175		return fantasy.FinishReasonStop
176	case "length":
177		return fantasy.FinishReasonLength
178	case "content_filter":
179		return fantasy.FinishReasonContentFilter
180	case "function_call", "tool_calls":
181		return fantasy.FinishReasonToolCalls
182	default:
183		return fantasy.FinishReasonUnknown
184	}
185}
186
187// DefaultUsageFunc is the default implementation for calculating usage.
188func DefaultUsageFunc(response openai.ChatCompletion) (fantasy.Usage, fantasy.ProviderOptionsData) {
189	completionTokenDetails := response.Usage.CompletionTokensDetails
190	promptTokenDetails := response.Usage.PromptTokensDetails
191
192	// Build provider metadata
193	providerMetadata := &ProviderMetadata{}
194
195	// Add logprobs if available
196	if len(response.Choices) > 0 && len(response.Choices[0].Logprobs.Content) > 0 {
197		providerMetadata.Logprobs = response.Choices[0].Logprobs.Content
198	}
199
200	// Add prediction tokens if available
201	if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
202		if completionTokenDetails.AcceptedPredictionTokens > 0 {
203			providerMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
204		}
205		if completionTokenDetails.RejectedPredictionTokens > 0 {
206			providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
207		}
208	}
209	return fantasy.Usage{
210		InputTokens:     response.Usage.PromptTokens,
211		OutputTokens:    response.Usage.CompletionTokens,
212		TotalTokens:     response.Usage.TotalTokens,
213		ReasoningTokens: completionTokenDetails.ReasoningTokens,
214		CacheReadTokens: promptTokenDetails.CachedTokens,
215	}, providerMetadata
216}
217
218// DefaultStreamUsageFunc is the default implementation for calculating stream usage.
219func DefaultStreamUsageFunc(chunk openai.ChatCompletionChunk, _ map[string]any, metadata fantasy.ProviderMetadata) (fantasy.Usage, fantasy.ProviderMetadata) {
220	if chunk.Usage.TotalTokens == 0 {
221		return fantasy.Usage{}, nil
222	}
223	streamProviderMetadata := &ProviderMetadata{}
224	if metadata != nil {
225		if providerMetadata, ok := metadata[Name]; ok {
226			converted, ok := providerMetadata.(*ProviderMetadata)
227			if ok {
228				streamProviderMetadata = converted
229			}
230		}
231	}
232	// we do this here because the acc does not add prompt details
233	completionTokenDetails := chunk.Usage.CompletionTokensDetails
234	promptTokenDetails := chunk.Usage.PromptTokensDetails
235	usage := fantasy.Usage{
236		InputTokens:     chunk.Usage.PromptTokens,
237		OutputTokens:    chunk.Usage.CompletionTokens,
238		TotalTokens:     chunk.Usage.TotalTokens,
239		ReasoningTokens: completionTokenDetails.ReasoningTokens,
240		CacheReadTokens: promptTokenDetails.CachedTokens,
241	}
242
243	// Add prediction tokens if available
244	if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
245		if completionTokenDetails.AcceptedPredictionTokens > 0 {
246			streamProviderMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
247		}
248		if completionTokenDetails.RejectedPredictionTokens > 0 {
249			streamProviderMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
250		}
251	}
252
253	return usage, fantasy.ProviderMetadata{
254		Name: streamProviderMetadata,
255	}
256}
257
258// DefaultStreamProviderMetadataFunc is the default implementation for handling stream provider metadata.
259func DefaultStreamProviderMetadataFunc(choice openai.ChatCompletionChoice, metadata fantasy.ProviderMetadata) fantasy.ProviderMetadata {
260	streamProviderMetadata, ok := metadata[Name]
261	if !ok {
262		streamProviderMetadata = &ProviderMetadata{}
263	}
264	if converted, ok := streamProviderMetadata.(*ProviderMetadata); ok {
265		converted.Logprobs = choice.Logprobs.Content
266		metadata[Name] = converted
267	}
268	return metadata
269}