fix(session): preserve estimated usage marker

Greg Slepak created

Keep estimated usage state in memory across session fetch-modify-save updates so unrelated saves do not clear the UI marker, and align the marker color with context percentages.

Change summary

internal/session/session.go      | 49 ++++++++++++++++++--
internal/session/session_test.go | 81 ++++++++++++++++++++++++++++++++++
internal/ui/styles/quickstyle.go |  2 
3 files changed, 125 insertions(+), 7 deletions(-)

Detailed changes

internal/session/session.go 🔗

@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"log/slog"
 	"strings"
+	"sync"
 
 	"github.com/charmbracelet/crush/internal/db"
 	"github.com/charmbracelet/crush/internal/event"
@@ -84,6 +85,12 @@ type service struct {
 	*pubsub.Broker[Session]
 	db *sql.DB
 	q  *db.Queries
+
+	// Estimated usage stays in memory so fetch-modify-save paths (e.g.,
+	// updating todos or parent-session cost) do not rebuild a session from
+	// SQLite and incorrectly clear the UI "~" marker.
+	estimatedUsageMu sync.RWMutex
+	estimatedUsage   map[string]bool
 }
 
 func (s *service) Create(ctx context.Context, title string) (Session, error) {
@@ -155,6 +162,7 @@ func (s *service) Delete(ctx context.Context, id string) error {
 	}
 
 	session := s.fromDBItem(dbSession)
+	s.clearEstimatedUsageState(dbSession.ID)
 	s.Publish(pubsub.DeletedEvent, session)
 	event.SessionDeleted()
 	return nil
@@ -165,7 +173,9 @@ func (s *service) Get(ctx context.Context, id string) (Session, error) {
 	if err != nil {
 		return Session{}, err
 	}
-	return s.fromDBItem(dbSession), nil
+	session := s.fromDBItem(dbSession)
+	s.applyEstimatedUsageState(&session)
+	return session, nil
 }
 
 func (s *service) GetLast(ctx context.Context) (Session, error) {
@@ -173,7 +183,9 @@ func (s *service) GetLast(ctx context.Context) (Session, error) {
 	if err != nil {
 		return Session{}, err
 	}
-	return s.fromDBItem(dbSession), nil
+	session := s.fromDBItem(dbSession)
+	s.applyEstimatedUsageState(&session)
+	return session, nil
 }
 
 func (s *service) Save(ctx context.Context, session Session) (Session, error) {
@@ -201,6 +213,7 @@ func (s *service) Save(ctx context.Context, session Session) (Session, error) {
 		return Session{}, err
 	}
 	estimatedUsage := session.EstimatedUsage
+	s.setEstimatedUsageState(session.ID, estimatedUsage)
 	session = s.fromDBItem(dbSession)
 	session.EstimatedUsage = estimatedUsage
 	s.Publish(pubsub.UpdatedEvent, session)
@@ -236,11 +249,34 @@ func (s *service) List(ctx context.Context) ([]Session, error) {
 	sessions := make([]Session, len(dbSessions))
 	for i, dbSession := range dbSessions {
 		sessions[i] = s.fromDBItem(dbSession)
+		s.applyEstimatedUsageState(&sessions[i])
 	}
 	return sessions, nil
 }
 
-func (s service) fromDBItem(item db.Session) Session {
+func (s *service) applyEstimatedUsageState(session *Session) {
+	s.estimatedUsageMu.RLock()
+	session.EstimatedUsage = s.estimatedUsage[session.ID]
+	s.estimatedUsageMu.RUnlock()
+}
+
+func (s *service) setEstimatedUsageState(sessionID string, estimatedUsage bool) {
+	s.estimatedUsageMu.Lock()
+	defer s.estimatedUsageMu.Unlock()
+	if estimatedUsage {
+		s.estimatedUsage[sessionID] = true
+		return
+	}
+	delete(s.estimatedUsage, sessionID)
+}
+
+func (s *service) clearEstimatedUsageState(sessionID string) {
+	s.estimatedUsageMu.Lock()
+	delete(s.estimatedUsage, sessionID)
+	s.estimatedUsageMu.Unlock()
+}
+
+func (s *service) fromDBItem(item db.Session) Session {
 	todos, err := unmarshalTodos(item.Todos.String)
 	if err != nil {
 		slog.Error("Failed to unmarshal todos", "session_id", item.ID, "error", err)
@@ -285,9 +321,10 @@ func unmarshalTodos(data string) ([]Todo, error) {
 func NewService(q *db.Queries, conn *sql.DB) Service {
 	broker := pubsub.NewBroker[Session]()
 	return &service{
-		Broker: broker,
-		db:     conn,
-		q:      q,
+		Broker:         broker,
+		db:             conn,
+		q:              q,
+		estimatedUsage: make(map[string]bool),
 	}
 }
 

internal/session/session_test.go 🔗

@@ -0,0 +1,81 @@
+package session
+
+import (
+	"testing"
+
+	"github.com/charmbracelet/crush/internal/db"
+	"github.com/stretchr/testify/require"
+)
+
+func TestEstimatedUsageStateSurvivesFetchModifySave(t *testing.T) {
+	dataDir := t.TempDir()
+	t.Cleanup(func() {
+		require.NoError(t, db.Release(dataDir))
+		db.ResetPool()
+	})
+
+	conn, err := db.Connect(t.Context(), dataDir)
+	require.NoError(t, err)
+
+	sessions := NewService(db.New(conn), conn)
+
+	created, err := sessions.Create(t.Context(), "test")
+	require.NoError(t, err)
+	created.PromptTokens = 100
+	created.CompletionTokens = 50
+	created.EstimatedUsage = true
+
+	saved, err := sessions.Save(t.Context(), created)
+	require.NoError(t, err)
+	require.True(t, saved.EstimatedUsage)
+
+	fetched, err := sessions.Get(t.Context(), created.ID)
+	require.NoError(t, err)
+	require.True(t, fetched.EstimatedUsage)
+
+	fetched.Todos = []Todo{{
+		Content:    "Check estimate state",
+		Status:     TodoStatusInProgress,
+		ActiveForm: "Checking estimate state",
+	}}
+
+	updated, err := sessions.Save(t.Context(), fetched)
+	require.NoError(t, err)
+	require.True(t, updated.EstimatedUsage)
+
+	refetched, err := sessions.Get(t.Context(), created.ID)
+	require.NoError(t, err)
+	require.True(t, refetched.EstimatedUsage)
+}
+
+func TestEstimatedUsageStateCanBeClearedByExplicitSave(t *testing.T) {
+	dataDir := t.TempDir()
+	t.Cleanup(func() {
+		require.NoError(t, db.Release(dataDir))
+		db.ResetPool()
+	})
+
+	conn, err := db.Connect(t.Context(), dataDir)
+	require.NoError(t, err)
+
+	sessions := NewService(db.New(conn), conn)
+
+	created, err := sessions.Create(t.Context(), "test")
+	require.NoError(t, err)
+	created.PromptTokens = 100
+	created.CompletionTokens = 50
+	created.EstimatedUsage = true
+
+	saved, err := sessions.Save(t.Context(), created)
+	require.NoError(t, err)
+	require.True(t, saved.EstimatedUsage)
+
+	saved.EstimatedUsage = false
+	updated, err := sessions.Save(t.Context(), saved)
+	require.NoError(t, err)
+	require.False(t, updated.EstimatedUsage)
+
+	refetched, err := sessions.Get(t.Context(), created.ID)
+	require.NoError(t, err)
+	require.False(t, refetched.EstimatedUsage)
+}

internal/ui/styles/quickstyle.go 🔗

@@ -763,7 +763,7 @@ func quickStyle(o quickStyleOpts) Styles {
 	s.ModelInfo.Reasoning = lipgloss.NewStyle().Foreground(o.fgMostSubtle).PaddingLeft(2)
 	s.ModelInfo.TokenCount = lipgloss.NewStyle().Foreground(o.fgMostSubtle)
 	s.ModelInfo.TokenPercentage = lipgloss.NewStyle().Foreground(o.fgMoreSubtle)
-	s.ModelInfo.EstimatedUsagePrefix = lipgloss.NewStyle().Foreground(o.fgBase)
+	s.ModelInfo.EstimatedUsagePrefix = s.ModelInfo.TokenPercentage
 	s.ModelInfo.Cost = lipgloss.NewStyle().Foreground(o.fgMoreSubtle)
 	s.ModelInfo.HypercreditIcon = lipgloss.NewStyle().Foreground(charmtone.Dolly)
 	s.ModelInfo.HypercreditText = lipgloss.NewStyle().Foreground(o.fgMoreSubtle)