language_model_hooks.go

  1package openrouter
  2
  3import (
  4	"encoding/json"
  5	"fmt"
  6	"maps"
  7
  8	"charm.land/fantasy"
  9	"charm.land/fantasy/anthropic"
 10	openaisdk "github.com/openai/openai-go/v2"
 11	"github.com/openai/openai-go/v2/packages/param"
 12)
 13
// reasoningStartedCtx is the key under which languageModelStreamExtra keeps,
// in its context map, a bool recording whether reasoning output is currently
// in progress for the stream.
const reasoningStartedCtx = "reasoning_started"
 15
 16func languagePrepareModelCall(model fantasy.LanguageModel, params *openaisdk.ChatCompletionNewParams, call fantasy.Call) ([]fantasy.CallWarning, error) {
 17	providerOptions := &ProviderOptions{}
 18	if v, ok := call.ProviderOptions[Name]; ok {
 19		providerOptions, ok = v.(*ProviderOptions)
 20		if !ok {
 21			return nil, fantasy.NewInvalidArgumentError("providerOptions", "openrouter provider options should be *openrouter.ProviderOptions", nil)
 22		}
 23	}
 24
 25	extraFields := make(map[string]any)
 26
 27	if providerOptions.Provider != nil {
 28		data, err := structToMapJSON(providerOptions.Provider)
 29		if err != nil {
 30			return nil, err
 31		}
 32		extraFields["provider"] = data
 33	}
 34
 35	if providerOptions.Reasoning != nil {
 36		data, err := structToMapJSON(providerOptions.Reasoning)
 37		if err != nil {
 38			return nil, err
 39		}
 40		extraFields["reasoning"] = data
 41	}
 42
 43	if providerOptions.IncludeUsage != nil {
 44		extraFields["usage"] = map[string]any{
 45			"include": *providerOptions.IncludeUsage,
 46		}
 47	} else { // default include usage
 48		extraFields["usage"] = map[string]any{
 49			"include": true,
 50		}
 51	}
 52	if providerOptions.LogitBias != nil {
 53		params.LogitBias = providerOptions.LogitBias
 54	}
 55	if providerOptions.LogProbs != nil {
 56		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
 57	}
 58	if providerOptions.User != nil {
 59		params.User = param.NewOpt(*providerOptions.User)
 60	}
 61	if providerOptions.ParallelToolCalls != nil {
 62		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
 63	}
 64
 65	maps.Copy(extraFields, providerOptions.ExtraBody)
 66	params.SetExtraFields(extraFields)
 67	return nil, nil
 68}
 69
 70func languageModelExtraContent(choice openaisdk.ChatCompletionChoice) []fantasy.Content {
 71	var content []fantasy.Content
 72	reasoningData := ReasoningData{}
 73	err := json.Unmarshal([]byte(choice.Message.RawJSON()), &reasoningData)
 74	if err != nil {
 75		return content
 76	}
 77	for _, detail := range reasoningData.ReasoningDetails {
 78		var metadata fantasy.ProviderMetadata
 79
 80		if detail.Signature != "" {
 81			metadata = fantasy.ProviderMetadata{
 82				Name: &ReasoningMetadata{
 83					Signature: detail.Signature,
 84				},
 85				anthropic.Name: &anthropic.ReasoningOptionMetadata{
 86					Signature: detail.Signature,
 87				},
 88			}
 89		}
 90		switch detail.Type {
 91		case "reasoning.text":
 92			content = append(content, fantasy.ReasoningContent{
 93				Text:             detail.Text,
 94				ProviderMetadata: metadata,
 95			})
 96		case "reasoning.summary":
 97			content = append(content, fantasy.ReasoningContent{
 98				Text:             detail.Summary,
 99				ProviderMetadata: metadata,
100			})
101		case "reasoning.encrypted":
102			content = append(content, fantasy.ReasoningContent{
103				Text:             "[REDACTED]",
104				ProviderMetadata: metadata,
105			})
106		}
107	}
108	return content
109}
110
111func extractReasoningContext(ctx map[string]any) bool {
112	reasoningStarted, ok := ctx[reasoningStartedCtx]
113	if !ok {
114		return false
115	}
116	b, ok := reasoningStarted.(bool)
117	if !ok {
118		return false
119	}
120	return b
121}
122
123func languageModelStreamExtra(chunk openaisdk.ChatCompletionChunk, yield func(fantasy.StreamPart) bool, ctx map[string]any) (map[string]any, bool) {
124	if len(chunk.Choices) == 0 {
125		return ctx, true
126	}
127
128	reasoningStarted := extractReasoningContext(ctx)
129
130	for inx, choice := range chunk.Choices {
131		reasoningData := ReasoningData{}
132		err := json.Unmarshal([]byte(choice.Delta.RawJSON()), &reasoningData)
133		if err != nil {
134			yield(fantasy.StreamPart{
135				Type:  fantasy.StreamPartTypeError,
136				Error: fantasy.NewAIError("Unexpected", "error unmarshalling delta", err),
137			})
138			return ctx, false
139		}
140
141		emitEvent := func(reasoningContent string, signature string) bool {
142			if !reasoningStarted {
143				shouldContinue := yield(fantasy.StreamPart{
144					Type: fantasy.StreamPartTypeReasoningStart,
145					ID:   fmt.Sprintf("%d", inx),
146				})
147				if !shouldContinue {
148					return false
149				}
150			}
151
152			var metadata fantasy.ProviderMetadata
153
154			if signature != "" {
155				metadata = fantasy.ProviderMetadata{
156					Name: &ReasoningMetadata{
157						Signature: signature,
158					},
159					anthropic.Name: &anthropic.ReasoningOptionMetadata{
160						Signature: signature,
161					},
162				}
163			}
164
165			return yield(fantasy.StreamPart{
166				Type:             fantasy.StreamPartTypeReasoningDelta,
167				ID:               fmt.Sprintf("%d", inx),
168				Delta:            reasoningContent,
169				ProviderMetadata: metadata,
170			})
171		}
172		if len(reasoningData.ReasoningDetails) > 0 {
173			for _, detail := range reasoningData.ReasoningDetails {
174				if !reasoningStarted {
175					ctx[reasoningStartedCtx] = true
176				}
177				switch detail.Type {
178				case "reasoning.text":
179					return ctx, emitEvent(detail.Text, detail.Signature)
180				case "reasoning.summary":
181					return ctx, emitEvent(detail.Summary, detail.Signature)
182				case "reasoning.encrypted":
183					return ctx, emitEvent("[REDACTED]", detail.Signature)
184				}
185			}
186		} else if reasoningData.Reasoning != "" {
187			return ctx, emitEvent(reasoningData.Reasoning, "")
188		}
189		if reasoningStarted && (choice.Delta.Content != "" || len(choice.Delta.ToolCalls) > 0) {
190			ctx[reasoningStartedCtx] = false
191			return ctx, yield(fantasy.StreamPart{
192				Type: fantasy.StreamPartTypeReasoningEnd,
193				ID:   fmt.Sprintf("%d", inx),
194			})
195		}
196	}
197	return ctx, true
198}
199
200func languageModelUsage(response openaisdk.ChatCompletion) (fantasy.Usage, fantasy.ProviderOptionsData) {
201	if len(response.Choices) == 0 {
202		return fantasy.Usage{}, nil
203	}
204	openrouterUsage := UsageAccounting{}
205	usage := response.Usage
206
207	_ = json.Unmarshal([]byte(usage.RawJSON()), &openrouterUsage)
208
209	completionTokenDetails := usage.CompletionTokensDetails
210	promptTokenDetails := usage.PromptTokensDetails
211
212	var provider string
213	if p, ok := response.JSON.ExtraFields["provider"]; ok {
214		provider = p.Raw()
215	}
216
217	// Build provider metadata
218	providerMetadata := &ProviderMetadata{
219		Provider: provider,
220		Usage:    openrouterUsage,
221	}
222
223	return fantasy.Usage{
224		InputTokens:     usage.PromptTokens,
225		OutputTokens:    usage.CompletionTokens,
226		TotalTokens:     usage.TotalTokens,
227		ReasoningTokens: completionTokenDetails.ReasoningTokens,
228		CacheReadTokens: promptTokenDetails.CachedTokens,
229	}, providerMetadata
230}
231
232func languageModelStreamUsage(chunk openaisdk.ChatCompletionChunk, _ map[string]any, metadata fantasy.ProviderMetadata) (fantasy.Usage, fantasy.ProviderMetadata) {
233	usage := chunk.Usage
234	if usage.TotalTokens == 0 {
235		return fantasy.Usage{}, nil
236	}
237
238	streamProviderMetadata := &ProviderMetadata{}
239	if metadata != nil {
240		if providerMetadata, ok := metadata[Name]; ok {
241			converted, ok := providerMetadata.(*ProviderMetadata)
242			if ok {
243				streamProviderMetadata = converted
244			}
245		}
246	}
247	openrouterUsage := UsageAccounting{}
248	_ = json.Unmarshal([]byte(usage.RawJSON()), &openrouterUsage)
249	streamProviderMetadata.Usage = openrouterUsage
250
251	if p, ok := chunk.JSON.ExtraFields["provider"]; ok {
252		streamProviderMetadata.Provider = p.Raw()
253	}
254
255	// we do this here because the acc does not add prompt details
256	completionTokenDetails := usage.CompletionTokensDetails
257	promptTokenDetails := usage.PromptTokensDetails
258	aiUsage := fantasy.Usage{
259		InputTokens:     usage.PromptTokens,
260		OutputTokens:    usage.CompletionTokens,
261		TotalTokens:     usage.TotalTokens,
262		ReasoningTokens: completionTokenDetails.ReasoningTokens,
263		CacheReadTokens: promptTokenDetails.CachedTokens,
264	}
265
266	return aiUsage, fantasy.ProviderMetadata{
267		Name: streamProviderMetadata,
268	}
269}