session.go

  1package session
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"encoding/json"
  7	"fmt"
  8	"log/slog"
  9	"strings"
 10	"sync"
 11
 12	"github.com/charmbracelet/crush/internal/db"
 13	"github.com/charmbracelet/crush/internal/event"
 14	"github.com/charmbracelet/crush/internal/pubsub"
 15	"github.com/google/uuid"
 16	"github.com/zeebo/xxh3"
 17)
 18
 19type TodoStatus string
 20
 21const (
 22	TodoStatusPending    TodoStatus = "pending"
 23	TodoStatusInProgress TodoStatus = "in_progress"
 24	TodoStatusCompleted  TodoStatus = "completed"
 25)
 26
 27// HashID returns the XXH3 hash of a session ID (UUID) as a hex string.
 28func HashID(id string) string {
 29	h := xxh3.New()
 30	h.WriteString(id)
 31	return fmt.Sprintf("%x", h.Sum(nil))
 32}
 33
 34type Todo struct {
 35	Content    string     `json:"content"`
 36	Status     TodoStatus `json:"status"`
 37	ActiveForm string     `json:"active_form"`
 38}
 39
 40// HasIncompleteTodos returns true if there are any non-completed todos.
 41func HasIncompleteTodos(todos []Todo) bool {
 42	for _, todo := range todos {
 43		if todo.Status != TodoStatusCompleted {
 44			return true
 45		}
 46	}
 47	return false
 48}
 49
 50type Session struct {
 51	ID               string
 52	ParentSessionID  string
 53	Title            string
 54	MessageCount     int64
 55	PromptTokens     int64
 56	CompletionTokens int64
 57	EstimatedUsage   bool
 58	SummaryMessageID string
 59	Cost             float64
 60	Todos            []Todo
 61	CreatedAt        int64
 62	UpdatedAt        int64
 63}
 64
 65type Service interface {
 66	pubsub.Subscriber[Session]
 67	Create(ctx context.Context, title string) (Session, error)
 68	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
 69	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
 70	Get(ctx context.Context, id string) (Session, error)
 71	GetLast(ctx context.Context) (Session, error)
 72	List(ctx context.Context) ([]Session, error)
 73	Save(ctx context.Context, session Session) (Session, error)
 74	UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
 75	Rename(ctx context.Context, id string, title string) error
 76	Delete(ctx context.Context, id string) error
 77
 78	// Agent tool session management
 79	CreateAgentToolSessionID(messageID, toolCallID string) string
 80	ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)
 81	IsAgentToolSession(sessionID string) bool
 82}
 83
 84type service struct {
 85	*pubsub.Broker[Session]
 86	db *sql.DB
 87	q  *db.Queries
 88
 89	// Estimated usage stays in memory so fetch-modify-save paths (e.g.,
 90	// updating todos or parent-session cost) do not rebuild a session from
 91	// SQLite and incorrectly clear the UI "~" marker.
 92	estimatedUsageMu sync.RWMutex
 93	estimatedUsage   map[string]bool
 94}
 95
 96func (s *service) Create(ctx context.Context, title string) (Session, error) {
 97	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 98		ID:    uuid.New().String(),
 99		Title: title,
100	})
101	if err != nil {
102		return Session{}, err
103	}
104	session := s.fromDBItem(dbSession)
105	s.Publish(pubsub.CreatedEvent, session)
106	event.SessionCreated()
107	return session, nil
108}
109
110func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
111	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
112		ID:              toolCallID,
113		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
114		Title:           title,
115	})
116	if err != nil {
117		return Session{}, err
118	}
119	session := s.fromDBItem(dbSession)
120	s.Publish(pubsub.CreatedEvent, session)
121	return session, nil
122}
123
124func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
125	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
126		ID:              "title-" + parentSessionID,
127		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
128		Title:           "Generate a title",
129	})
130	if err != nil {
131		return Session{}, err
132	}
133	session := s.fromDBItem(dbSession)
134	s.Publish(pubsub.CreatedEvent, session)
135	return session, nil
136}
137
138func (s *service) Delete(ctx context.Context, id string) error {
139	tx, err := s.db.BeginTx(ctx, nil)
140	if err != nil {
141		return fmt.Errorf("beginning transaction: %w", err)
142	}
143	defer tx.Rollback() //nolint:errcheck
144
145	qtx := s.q.WithTx(tx)
146
147	dbSession, err := qtx.GetSessionByID(ctx, id)
148	if err != nil {
149		return err
150	}
151	if err = qtx.DeleteSessionMessages(ctx, dbSession.ID); err != nil {
152		return fmt.Errorf("deleting session messages: %w", err)
153	}
154	if err = qtx.DeleteSessionFiles(ctx, dbSession.ID); err != nil {
155		return fmt.Errorf("deleting session files: %w", err)
156	}
157	if err = qtx.DeleteSession(ctx, dbSession.ID); err != nil {
158		return fmt.Errorf("deleting session: %w", err)
159	}
160	if err = tx.Commit(); err != nil {
161		return fmt.Errorf("committing transaction: %w", err)
162	}
163
164	session := s.fromDBItem(dbSession)
165	s.clearEstimatedUsageState(dbSession.ID)
166	s.Publish(pubsub.DeletedEvent, session)
167	event.SessionDeleted()
168	return nil
169}
170
171func (s *service) Get(ctx context.Context, id string) (Session, error) {
172	dbSession, err := s.q.GetSessionByID(ctx, id)
173	if err != nil {
174		return Session{}, err
175	}
176	session := s.fromDBItem(dbSession)
177	s.applyEstimatedUsageState(&session)
178	return session, nil
179}
180
181func (s *service) GetLast(ctx context.Context) (Session, error) {
182	dbSession, err := s.q.GetLastSession(ctx)
183	if err != nil {
184		return Session{}, err
185	}
186	session := s.fromDBItem(dbSession)
187	s.applyEstimatedUsageState(&session)
188	return session, nil
189}
190
191func (s *service) Save(ctx context.Context, session Session) (Session, error) {
192	todosJSON, err := marshalTodos(session.Todos)
193	if err != nil {
194		return Session{}, err
195	}
196
197	dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
198		ID:               session.ID,
199		Title:            session.Title,
200		PromptTokens:     session.PromptTokens,
201		CompletionTokens: session.CompletionTokens,
202		SummaryMessageID: sql.NullString{
203			String: session.SummaryMessageID,
204			Valid:  session.SummaryMessageID != "",
205		},
206		Cost: session.Cost,
207		Todos: sql.NullString{
208			String: todosJSON,
209			Valid:  todosJSON != "",
210		},
211	})
212	if err != nil {
213		return Session{}, err
214	}
215	estimatedUsage := session.EstimatedUsage
216	s.setEstimatedUsageState(session.ID, estimatedUsage)
217	session = s.fromDBItem(dbSession)
218	session.EstimatedUsage = estimatedUsage
219	s.Publish(pubsub.UpdatedEvent, session)
220	return session, nil
221}
222
223// UpdateTitleAndUsage updates only the title and usage fields atomically.
224// This is safer than fetching, modifying, and saving the entire session.
225func (s *service) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error {
226	return s.q.UpdateSessionTitleAndUsage(ctx, db.UpdateSessionTitleAndUsageParams{
227		ID:               sessionID,
228		Title:            title,
229		PromptTokens:     promptTokens,
230		CompletionTokens: completionTokens,
231		Cost:             cost,
232	})
233}
234
235// Rename updates only the title of a session without touching updated_at or
236// usage fields.
237func (s *service) Rename(ctx context.Context, id string, title string) error {
238	return s.q.RenameSession(ctx, db.RenameSessionParams{
239		ID:    id,
240		Title: title,
241	})
242}
243
244func (s *service) List(ctx context.Context) ([]Session, error) {
245	dbSessions, err := s.q.ListSessions(ctx)
246	if err != nil {
247		return nil, err
248	}
249	sessions := make([]Session, len(dbSessions))
250	for i, dbSession := range dbSessions {
251		sessions[i] = s.fromDBItem(dbSession)
252		s.applyEstimatedUsageState(&sessions[i])
253	}
254	return sessions, nil
255}
256
257func (s *service) applyEstimatedUsageState(session *Session) {
258	s.estimatedUsageMu.RLock()
259	session.EstimatedUsage = s.estimatedUsage[session.ID]
260	s.estimatedUsageMu.RUnlock()
261}
262
263func (s *service) setEstimatedUsageState(sessionID string, estimatedUsage bool) {
264	s.estimatedUsageMu.Lock()
265	defer s.estimatedUsageMu.Unlock()
266	if estimatedUsage {
267		s.estimatedUsage[sessionID] = true
268		return
269	}
270	delete(s.estimatedUsage, sessionID)
271}
272
273func (s *service) clearEstimatedUsageState(sessionID string) {
274	s.estimatedUsageMu.Lock()
275	delete(s.estimatedUsage, sessionID)
276	s.estimatedUsageMu.Unlock()
277}
278
279func (s *service) fromDBItem(item db.Session) Session {
280	todos, err := unmarshalTodos(item.Todos.String)
281	if err != nil {
282		slog.Error("Failed to unmarshal todos", "session_id", item.ID, "error", err)
283	}
284	return Session{
285		ID:               item.ID,
286		ParentSessionID:  item.ParentSessionID.String,
287		Title:            item.Title,
288		MessageCount:     item.MessageCount,
289		PromptTokens:     item.PromptTokens,
290		CompletionTokens: item.CompletionTokens,
291		SummaryMessageID: item.SummaryMessageID.String,
292		Cost:             item.Cost,
293		Todos:            todos,
294		CreatedAt:        item.CreatedAt,
295		UpdatedAt:        item.UpdatedAt,
296	}
297}
298
299func marshalTodos(todos []Todo) (string, error) {
300	if len(todos) == 0 {
301		return "", nil
302	}
303	data, err := json.Marshal(todos)
304	if err != nil {
305		return "", err
306	}
307	return string(data), nil
308}
309
310func unmarshalTodos(data string) ([]Todo, error) {
311	if data == "" {
312		return []Todo{}, nil
313	}
314	var todos []Todo
315	if err := json.Unmarshal([]byte(data), &todos); err != nil {
316		return []Todo{}, err
317	}
318	return todos, nil
319}
320
321func NewService(q *db.Queries, conn *sql.DB) Service {
322	broker := pubsub.NewBroker[Session]()
323	return &service{
324		Broker:         broker,
325		db:             conn,
326		q:              q,
327		estimatedUsage: make(map[string]bool),
328	}
329}
330
331// CreateAgentToolSessionID creates a session ID for agent tool sessions using the format "messageID$$toolCallID"
332func (s *service) CreateAgentToolSessionID(messageID, toolCallID string) string {
333	return fmt.Sprintf("%s$$%s", messageID, toolCallID)
334}
335
336// ParseAgentToolSessionID parses an agent tool session ID into its components
337func (s *service) ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) {
338	parts := strings.Split(sessionID, "$$")
339	if len(parts) != 2 {
340		return "", "", false
341	}
342	return parts[0], parts[1], true
343}
344
345// IsAgentToolSession checks if a session ID follows the agent tool session format
346func (s *service) IsAgentToolSession(sessionID string) bool {
347	_, _, ok := s.ParseAgentToolSessionID(sessionID)
348	return ok
349}