@@ -7,6 +7,7 @@ import (
"fmt"
"log/slog"
"strings"
+ "sync"
"github.com/charmbracelet/crush/internal/db"
"github.com/charmbracelet/crush/internal/event"
@@ -84,6 +85,12 @@ type service struct {
*pubsub.Broker[Session]
db *sql.DB
q *db.Queries
+
+ // Estimated usage stays in memory so fetch-modify-save paths (e.g.,
+ // updating todos or parent-session cost) do not rebuild a session from
+ // SQLite and incorrectly clear the UI "~" marker.
+ estimatedUsageMu sync.RWMutex
+ estimatedUsage map[string]bool
}
func (s *service) Create(ctx context.Context, title string) (Session, error) {
@@ -155,6 +162,7 @@ func (s *service) Delete(ctx context.Context, id string) error {
}
session := s.fromDBItem(dbSession)
+ s.clearEstimatedUsageState(dbSession.ID)
s.Publish(pubsub.DeletedEvent, session)
event.SessionDeleted()
return nil
@@ -165,7 +173,9 @@ func (s *service) Get(ctx context.Context, id string) (Session, error) {
if err != nil {
return Session{}, err
}
- return s.fromDBItem(dbSession), nil
+ session := s.fromDBItem(dbSession)
+ s.applyEstimatedUsageState(&session)
+ return session, nil
}
func (s *service) GetLast(ctx context.Context) (Session, error) {
@@ -173,7 +183,9 @@ func (s *service) GetLast(ctx context.Context) (Session, error) {
if err != nil {
return Session{}, err
}
- return s.fromDBItem(dbSession), nil
+ session := s.fromDBItem(dbSession)
+ s.applyEstimatedUsageState(&session)
+ return session, nil
}
func (s *service) Save(ctx context.Context, session Session) (Session, error) {
@@ -201,6 +213,7 @@ func (s *service) Save(ctx context.Context, session Session) (Session, error) {
return Session{}, err
}
estimatedUsage := session.EstimatedUsage
+ s.setEstimatedUsageState(session.ID, estimatedUsage)
session = s.fromDBItem(dbSession)
session.EstimatedUsage = estimatedUsage
s.Publish(pubsub.UpdatedEvent, session)
@@ -236,11 +249,34 @@ func (s *service) List(ctx context.Context) ([]Session, error) {
sessions := make([]Session, len(dbSessions))
for i, dbSession := range dbSessions {
sessions[i] = s.fromDBItem(dbSession)
+ s.applyEstimatedUsageState(&sessions[i])
}
return sessions, nil
}
-func (s service) fromDBItem(item db.Session) Session {
+func (s *service) applyEstimatedUsageState(session *Session) {
+ s.estimatedUsageMu.RLock()
+ session.EstimatedUsage = s.estimatedUsage[session.ID]
+ s.estimatedUsageMu.RUnlock()
+}
+
+func (s *service) setEstimatedUsageState(sessionID string, estimatedUsage bool) {
+ s.estimatedUsageMu.Lock()
+ defer s.estimatedUsageMu.Unlock()
+ if estimatedUsage {
+ s.estimatedUsage[sessionID] = true
+ return
+ }
+ delete(s.estimatedUsage, sessionID)
+}
+
+func (s *service) clearEstimatedUsageState(sessionID string) {
+ s.estimatedUsageMu.Lock()
+ delete(s.estimatedUsage, sessionID)
+ s.estimatedUsageMu.Unlock()
+}
+
+func (s *service) fromDBItem(item db.Session) Session {
todos, err := unmarshalTodos(item.Todos.String)
if err != nil {
slog.Error("Failed to unmarshal todos", "session_id", item.ID, "error", err)
@@ -285,9 +321,10 @@ func unmarshalTodos(data string) ([]Todo, error) {
func NewService(q *db.Queries, conn *sql.DB) Service {
broker := pubsub.NewBroker[Session]()
return &service{
- Broker: broker,
- db: conn,
- q: q,
+ Broker: broker,
+ db: conn,
+ q: q,
+ estimatedUsage: make(map[string]bool),
}
}
@@ -0,0 +1,81 @@
+package session
+
+import (
+ "testing"
+
+ "github.com/charmbracelet/crush/internal/db"
+ "github.com/stretchr/testify/require"
+)
+
+func TestEstimatedUsageStateSurvivesFetchModifySave(t *testing.T) {
+ dataDir := t.TempDir()
+ t.Cleanup(func() {
+ require.NoError(t, db.Release(dataDir))
+ db.ResetPool()
+ })
+
+ conn, err := db.Connect(t.Context(), dataDir)
+ require.NoError(t, err)
+
+ sessions := NewService(db.New(conn), conn)
+
+ created, err := sessions.Create(t.Context(), "test")
+ require.NoError(t, err)
+ created.PromptTokens = 100
+ created.CompletionTokens = 50
+ created.EstimatedUsage = true
+
+ saved, err := sessions.Save(t.Context(), created)
+ require.NoError(t, err)
+ require.True(t, saved.EstimatedUsage)
+
+ fetched, err := sessions.Get(t.Context(), created.ID)
+ require.NoError(t, err)
+ require.True(t, fetched.EstimatedUsage)
+
+ fetched.Todos = []Todo{{
+ Content: "Check estimate state",
+ Status: TodoStatusInProgress,
+ ActiveForm: "Checking estimate state",
+ }}
+
+ updated, err := sessions.Save(t.Context(), fetched)
+ require.NoError(t, err)
+ require.True(t, updated.EstimatedUsage)
+
+ refetched, err := sessions.Get(t.Context(), created.ID)
+ require.NoError(t, err)
+ require.True(t, refetched.EstimatedUsage)
+}
+
+func TestEstimatedUsageStateCanBeClearedByExplicitSave(t *testing.T) {
+ dataDir := t.TempDir()
+ t.Cleanup(func() {
+ require.NoError(t, db.Release(dataDir))
+ db.ResetPool()
+ })
+
+ conn, err := db.Connect(t.Context(), dataDir)
+ require.NoError(t, err)
+
+ sessions := NewService(db.New(conn), conn)
+
+ created, err := sessions.Create(t.Context(), "test")
+ require.NoError(t, err)
+ created.PromptTokens = 100
+ created.CompletionTokens = 50
+ created.EstimatedUsage = true
+
+ saved, err := sessions.Save(t.Context(), created)
+ require.NoError(t, err)
+ require.True(t, saved.EstimatedUsage)
+
+ saved.EstimatedUsage = false
+ updated, err := sessions.Save(t.Context(), saved)
+ require.NoError(t, err)
+ require.False(t, updated.EstimatedUsage)
+
+ refetched, err := sessions.Get(t.Context(), created.ID)
+ require.NoError(t, err)
+ require.False(t, refetched.EstimatedUsage)
+}