diff --git a/internal/session/session.go b/internal/session/session.go index 6de6b9111d2f81fa49ae15e9ffaa9390f842d114..f6e3c8568adcf19e4eb8fc21db0ef1afd40cc4c5 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -7,6 +7,7 @@ import ( "fmt" "log/slog" "strings" + "sync" "github.com/charmbracelet/crush/internal/db" "github.com/charmbracelet/crush/internal/event" @@ -84,6 +85,12 @@ type service struct { *pubsub.Broker[Session] db *sql.DB q *db.Queries + + // Estimated usage stays in memory so fetch-modify-save paths (e.g., + // updating todos or parent-session cost) do not rebuild a session from + // SQLite and incorrectly clear the UI "~" marker. + estimatedUsageMu sync.RWMutex + estimatedUsage map[string]bool } func (s *service) Create(ctx context.Context, title string) (Session, error) { @@ -155,6 +162,7 @@ func (s *service) Delete(ctx context.Context, id string) error { } session := s.fromDBItem(dbSession) + s.clearEstimatedUsageState(dbSession.ID) s.Publish(pubsub.DeletedEvent, session) event.SessionDeleted() return nil @@ -165,7 +173,9 @@ func (s *service) Get(ctx context.Context, id string) (Session, error) { if err != nil { return Session{}, err } - return s.fromDBItem(dbSession), nil + session := s.fromDBItem(dbSession) + s.applyEstimatedUsageState(&session) + return session, nil } func (s *service) GetLast(ctx context.Context) (Session, error) { @@ -173,7 +183,9 @@ func (s *service) GetLast(ctx context.Context) (Session, error) { if err != nil { return Session{}, err } - return s.fromDBItem(dbSession), nil + session := s.fromDBItem(dbSession) + s.applyEstimatedUsageState(&session) + return session, nil } func (s *service) Save(ctx context.Context, session Session) (Session, error) { @@ -201,6 +213,7 @@ func (s *service) Save(ctx context.Context, session Session) (Session, error) { return Session{}, err } estimatedUsage := session.EstimatedUsage + s.setEstimatedUsageState(session.ID, estimatedUsage) session = s.fromDBItem(dbSession) session.EstimatedUsage = estimatedUsage s.Publish(pubsub.UpdatedEvent, session) @@ -236,11 +249,34 @@ func (s *service) List(ctx context.Context) ([]Session, error) { sessions := make([]Session, len(dbSessions)) for i, dbSession := range dbSessions { sessions[i] = s.fromDBItem(dbSession) + s.applyEstimatedUsageState(&sessions[i]) } return sessions, nil } -func (s service) fromDBItem(item db.Session) Session { +func (s *service) applyEstimatedUsageState(session *Session) { + s.estimatedUsageMu.RLock() + session.EstimatedUsage = s.estimatedUsage[session.ID] + s.estimatedUsageMu.RUnlock() +} + +func (s *service) setEstimatedUsageState(sessionID string, estimatedUsage bool) { + s.estimatedUsageMu.Lock() + defer s.estimatedUsageMu.Unlock() + if estimatedUsage { + s.estimatedUsage[sessionID] = true + return + } + delete(s.estimatedUsage, sessionID) +} + +func (s *service) clearEstimatedUsageState(sessionID string) { + s.estimatedUsageMu.Lock() + delete(s.estimatedUsage, sessionID) + s.estimatedUsageMu.Unlock() +} + +func (s *service) fromDBItem(item db.Session) Session { todos, err := unmarshalTodos(item.Todos.String) if err != nil { slog.Error("Failed to unmarshal todos", "session_id", item.ID, "error", err) @@ -285,9 +321,10 @@ func unmarshalTodos(data string) ([]Todo, error) { func NewService(q *db.Queries, conn *sql.DB) Service { broker := pubsub.NewBroker[Session]() return &service{ - Broker: broker, - db: conn, - q: q, + Broker: broker, + db: conn, + q: q, + estimatedUsage: make(map[string]bool), } } diff --git a/internal/session/session_test.go b/internal/session/session_test.go new file mode 100644 index 0000000000000000000000000000000000000000..50af7e23c130f5dda5627c2af29558cdec11799b --- /dev/null +++ b/internal/session/session_test.go @@ -0,0 +1,81 @@ +package session + +import ( + "testing" + + "github.com/charmbracelet/crush/internal/db" + "github.com/stretchr/testify/require" +) + +func TestEstimatedUsageStateSurvivesFetchModifySave(t *testing.T) { + dataDir := t.TempDir() + t.Cleanup(func() { + require.NoError(t, db.Release(dataDir)) + db.ResetPool() + }) + + conn, err := db.Connect(t.Context(), dataDir) + require.NoError(t, err) + + sessions := NewService(db.New(conn), conn) + + created, err := sessions.Create(t.Context(), "test") + require.NoError(t, err) + created.PromptTokens = 100 + created.CompletionTokens = 50 + created.EstimatedUsage = true + + saved, err := sessions.Save(t.Context(), created) + require.NoError(t, err) + require.True(t, saved.EstimatedUsage) + + fetched, err := sessions.Get(t.Context(), created.ID) + require.NoError(t, err) + require.True(t, fetched.EstimatedUsage) + + fetched.Todos = []Todo{{ + Content: "Check estimate state", + Status: TodoStatusInProgress, + ActiveForm: "Checking estimate state", + }} + + updated, err := sessions.Save(t.Context(), fetched) + require.NoError(t, err) + require.True(t, updated.EstimatedUsage) + + refetched, err := sessions.Get(t.Context(), created.ID) + require.NoError(t, err) + require.True(t, refetched.EstimatedUsage) +} + +func TestEstimatedUsageStateCanBeClearedByExplicitSave(t *testing.T) { + dataDir := t.TempDir() + t.Cleanup(func() { + require.NoError(t, db.Release(dataDir)) + db.ResetPool() + }) + + conn, err := db.Connect(t.Context(), dataDir) + require.NoError(t, err) + + sessions := NewService(db.New(conn), conn) + + created, err := sessions.Create(t.Context(), "test") + require.NoError(t, err) + created.PromptTokens = 100 + created.CompletionTokens = 50 + created.EstimatedUsage = true + + saved, err := sessions.Save(t.Context(), created) + require.NoError(t, err) + require.True(t, saved.EstimatedUsage) + + saved.EstimatedUsage = false + updated, err := sessions.Save(t.Context(), saved) + require.NoError(t, err) + require.False(t, updated.EstimatedUsage) + + refetched, err := sessions.Get(t.Context(), created.ID) + require.NoError(t, err) + require.False(t, refetched.EstimatedUsage) +} diff --git a/internal/ui/styles/quickstyle.go b/internal/ui/styles/quickstyle.go index cf685cf66f65a1abc388337a47e35ecbccc61c3d..b631dceff8041af6844586a1cf4585bd140c8b41 100644 --- a/internal/ui/styles/quickstyle.go +++ b/internal/ui/styles/quickstyle.go @@ -763,7 +763,7 @@ func quickStyle(o quickStyleOpts) Styles { s.ModelInfo.Reasoning = lipgloss.NewStyle().Foreground(o.fgMostSubtle).PaddingLeft(2) s.ModelInfo.TokenCount = lipgloss.NewStyle().Foreground(o.fgMostSubtle) s.ModelInfo.TokenPercentage = lipgloss.NewStyle().Foreground(o.fgMoreSubtle) - s.ModelInfo.EstimatedUsagePrefix = lipgloss.NewStyle().Foreground(o.fgBase) + s.ModelInfo.EstimatedUsagePrefix = s.ModelInfo.TokenPercentage s.ModelInfo.Cost = lipgloss.NewStyle().Foreground(o.fgMoreSubtle) s.ModelInfo.HypercreditIcon = lipgloss.NewStyle().Foreground(charmtone.Dolly) s.ModelInfo.HypercreditText = lipgloss.NewStyle().Foreground(o.fgMoreSubtle)