fix: cost calculation when using openrouter

Kujtim Hoxha created

Change summary

internal/agent/agent.go | 54 +++++++++++++++++++++++++++++++++++++++---
1 file changed, 49 insertions(+), 5 deletions(-)

Detailed changes

internal/agent/agent.go 🔗

@@ -17,6 +17,7 @@ import (
 	"charm.land/fantasy/providers/bedrock"
 	"charm.land/fantasy/providers/google"
 	"charm.land/fantasy/providers/openai"
+	"charm.land/fantasy/providers/openrouter"
 	"github.com/charmbracelet/catwalk/pkg/catwalk"
 	"github.com/charmbracelet/crush/internal/agent/tools"
 	"github.com/charmbracelet/crush/internal/config"
@@ -349,7 +350,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
 				finishReason = message.FinishReasonToolUse
 			}
 			currentAssistant.AddFinish(finishReason, "", "")
-			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage)
+			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
 			sessionLock.Lock()
 			_, sessionErr := a.sessions.Save(genCtx, currentSession)
 			sessionLock.Unlock()
@@ -562,7 +563,19 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan
 		return err
 	}
 
-	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage)
+	var openrouterCost *float64
+	for _, step := range resp.Steps {
+		stepCost := a.openrouterCost(step.ProviderMetadata)
+		if stepCost != nil {
+			newCost := *stepCost
+			if openrouterCost != nil {
+				newCost += *openrouterCost
+			}
+			openrouterCost = &newCost
+		}
+	}
+
+	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
 
 	// just in case get just the last usage
 	usage := resp.Response.Usage
@@ -690,7 +703,20 @@ func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Sessi
 	}
 
 	session.Title = title
-	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage)
+
+	var openrouterCost *float64
+	for _, step := range resp.Steps {
+		stepCost := a.openrouterCost(step.ProviderMetadata)
+		if stepCost != nil {
+			newCost := *stepCost
+			if openrouterCost != nil {
+				newCost += *openrouterCost
+			}
+			openrouterCost = &newCost
+		}
+	}
+
+	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
 	_, saveErr := a.sessions.Save(ctx, *session)
 	if saveErr != nil {
 		slog.Error("failed to save session title & usage", "error", saveErr)
@@ -698,7 +724,20 @@ func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Sessi
 	}
 }
 
-func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage) {
+func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
+	openrouterMetadata, ok := metadata[openrouter.Name]
+	if !ok {
+		return nil
+	}
+
+	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
+	if !ok {
+		return nil
+	}
+	return &opts.Usage.Cost
+}
+
+func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
 	modelConfig := model.CatwalkCfg
 	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
@@ -707,7 +746,12 @@ func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session,
 
 	a.eventTokensUsed(session.ID, model, usage, cost)
 
-	session.Cost += cost
+	if overrideCost != nil {
+		session.Cost += *overrideCost
+	} else {
+		session.Cost += cost
+	}
+
 	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 }