@@ -760,7 +760,9 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan
// Just in case, get just the last usage info.
usage := resp.Response.Usage
currentSession.SummaryMessageID = summaryMessage.ID
- currentSession.CompletionTokens = usage.OutputTokens
+ if completionTokens := summaryCompletionTokens(usage, summaryMessage); completionTokens != 0 {
+ currentSession.CompletionTokens = completionTokens
+ }
currentSession.PromptTokens = 0
_, err = a.sessions.Save(genCtx, currentSession)
if err != nil {
@@ -1145,11 +1147,9 @@ func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session,
modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
- eventCost := cost
- if estimated {
- eventCost = 0
+ if !estimated {
+ a.eventTokensUsed(session.ID, model, usage, cost)
}
- a.eventTokensUsed(session.ID, model, usage, eventCost)
if estimated {
cost = 0
@@ -1166,10 +1166,23 @@ func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session,
}
session.Cost += cost
- if !usageIsZero(usage) {
+ updateSessionTokenCounters(session, usage)
+}
+
+// updateSessionTokenCounters copies the token figures from usage onto session.
+// Each counter is written independently and only when its new value is
+// non-zero, so a zero figure in usage leaves the stored counter as it was.
+func updateSessionTokenCounters(session *session.Session, usage fantasy.Usage) {
+ // A zero output count does not overwrite the stored completion counter.
+ if usage.OutputTokens != 0 {
session.CompletionTokens = usage.OutputTokens
- session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
}
+ // The prompt counter is the sum of input and cache-read tokens; it is
+ // replaced only when that sum is non-zero.
+ if promptTokens := usage.InputTokens + usage.CacheReadTokens; promptTokens != 0 {
+ session.PromptTokens = promptTokens
+ }
+}
+
+// summaryCompletionTokens returns the completion-token figure to record for a
+// summary message. A non-zero usage.OutputTokens is returned unchanged;
+// otherwise the result is approxTokenCount of the summary text plus
+// approxTokenCount of the summary's reasoning content.
+func summaryCompletionTokens(usage fantasy.Usage, summaryMessage message.Message) int64 {
+	if reported := usage.OutputTokens; reported != 0 {
+		return reported
+	}
+
+	// No output count in usage: fall back to an approximation from the message.
+	textTokens := approxTokenCount(summaryMessage.Content().Text)
+	reasoningTokens := approxTokenCount(summaryMessage.ReasoningContent().String())
+	return textTokens + reasoningTokens
}
func (a *sessionAgent) Cancel(sessionID string) {
@@ -6,6 +6,7 @@ import (
"charm.land/catwalk/pkg/catwalk"
"charm.land/fantasy"
+ "github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/session"
"github.com/stretchr/testify/require"
)
@@ -232,7 +233,7 @@ func TestUpdateSessionUsageKeepsCountersForZeroUsage(t *testing.T) {
require.Equal(t, int64(456), currentSession.CompletionTokens)
}
-func TestUpdateSessionUsageReplacesCountersForPartialUsage(t *testing.T) {
+func TestUpdateSessionUsagePreservesOmittedCountersForPartialUsage(t *testing.T) {
t.Parallel()
agent := &sessionAgent{}
@@ -247,7 +248,77 @@ func TestUpdateSessionUsageReplacesCountersForPartialUsage(t *testing.T) {
agent.updateSessionUsage(model, currentSession, usage, nil, false)
require.Equal(t, int64(789), currentSession.PromptTokens)
- require.Zero(t, currentSession.CompletionTokens)
+ require.Equal(t, int64(456), currentSession.CompletionTokens)
+}
+
+// TestUpdateSessionUsagePreservesCountersForTotalOnlyUsage checks that a usage
+// value carrying only TotalTokens leaves both stored token counters unchanged.
+func TestUpdateSessionUsagePreservesCountersForTotalOnlyUsage(t *testing.T) {
+	t.Parallel()
+
+	const (
+		wantPrompt     = int64(123)
+		wantCompletion = int64(456)
+	)
+	sess := &session.Session{
+		ID:               "session-id",
+		PromptTokens:     wantPrompt,
+		CompletionTokens: wantCompletion,
+	}
+	pricing := catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}
+	subject := &sessionAgent{}
+
+	subject.updateSessionUsage(Model{CatwalkCfg: pricing}, sess, fantasy.Usage{TotalTokens: 100}, nil, false)
+
+	require.Equal(t, wantPrompt, sess.PromptTokens)
+	require.Equal(t, wantCompletion, sess.CompletionTokens)
+}
+
+// TestUpdateSessionUsagePreservesPromptForOutputOnlyUsage checks that a usage
+// value carrying only OutputTokens replaces the completion counter while the
+// stored prompt counter stays as it was.
+func TestUpdateSessionUsagePreservesPromptForOutputOnlyUsage(t *testing.T) {
+	t.Parallel()
+
+	const (
+		startPrompt = int64(123)
+		newOutput   = int64(50)
+	)
+	sess := &session.Session{
+		ID:               "session-id",
+		PromptTokens:     startPrompt,
+		CompletionTokens: 456,
+	}
+	pricing := catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}
+	subject := &sessionAgent{}
+
+	subject.updateSessionUsage(Model{CatwalkCfg: pricing}, sess, fantasy.Usage{OutputTokens: newOutput}, nil, false)
+
+	require.Equal(t, startPrompt, sess.PromptTokens)
+	require.Equal(t, newOutput, sess.CompletionTokens)
+}
+
+// TestUpdateSessionUsageKeepsCountersForEstimatedZeroUsage checks that an
+// estimated, all-zero usage value changes neither the session cost nor either
+// stored token counter.
+func TestUpdateSessionUsageKeepsCountersForEstimatedZeroUsage(t *testing.T) {
+	t.Parallel()
+
+	const (
+		startCost       = 1.25
+		startPrompt     = int64(123)
+		startCompletion = int64(456)
+	)
+	sess := &session.Session{
+		ID:               "session-id",
+		PromptTokens:     startPrompt,
+		CompletionTokens: startCompletion,
+		Cost:             startCost,
+	}
+	pricing := catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}
+	subject := &sessionAgent{}
+
+	// The final argument marks the usage as estimated.
+	subject.updateSessionUsage(Model{CatwalkCfg: pricing}, sess, fantasy.Usage{}, nil, true)
+
+	require.Equal(t, startCost, sess.Cost)
+	require.Equal(t, startPrompt, sess.PromptTokens)
+	require.Equal(t, startCompletion, sess.CompletionTokens)
+}
+
+// TestSummaryCompletionTokens covers the three cases of
+// summaryCompletionTokens: a non-zero output count in usage, a fallback
+// computed from the message's text and reasoning parts, and an empty message.
+func TestSummaryCompletionTokens(t *testing.T) {
+	t.Parallel()
+
+	const (
+		text      = "summary text"
+		reasoning = "reasoning text"
+	)
+	msg := message.Message{Parts: []message.ContentPart{
+		message.TextContent{Text: text},
+		message.ReasoningContent{Thinking: reasoning},
+	}}
+
+	// A non-zero OutputTokens value is returned as is.
+	reported := summaryCompletionTokens(fantasy.Usage{OutputTokens: 42}, msg)
+	require.Equal(t, int64(42), reported)
+
+	// With zero usage the result is the approximate count of both parts.
+	approximated := summaryCompletionTokens(fantasy.Usage{}, msg)
+	require.Equal(t, approxTokenCount(text)+approxTokenCount(reasoning), approximated)
+
+	// Zero usage and an empty message are expected to yield zero.
+	require.Zero(t, summaryCompletionTokens(fantasy.Usage{}, message.Message{}))
}
func TestUpdateSessionUsageAddsProviderCost(t *testing.T) {