session.go

  1package session
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"encoding/json"
  7	"fmt"
  8	"log/slog"
  9	"strings"
 10
 11	"github.com/charmbracelet/crush/internal/db"
 12	"github.com/charmbracelet/crush/internal/event"
 13	"github.com/charmbracelet/crush/internal/pubsub"
 14	"github.com/google/uuid"
 15	"github.com/zeebo/xxh3"
 16)
 17
// TodoStatus is the progress state of a single Todo item.
type TodoStatus string

// Known TodoStatus values. HasIncompleteTodos treats every status other than
// TodoStatusCompleted (including any unrecognized value) as incomplete.
const (
	TodoStatusPending    TodoStatus = "pending"
	TodoStatusInProgress TodoStatus = "in_progress"
	TodoStatusCompleted  TodoStatus = "completed"
)
 25
 26// HashID returns the XXH3 hash of a session ID (UUID) as a hex string.
 27func HashID(id string) string {
 28	h := xxh3.New()
 29	h.WriteString(id)
 30	return fmt.Sprintf("%x", h.Sum(nil))
 31}
 32
// Todo is a single item in a session's todo list. Save persists the whole
// list as a JSON array using the field tags below (via marshalTodos), and
// fromDBItem decodes it back (via unmarshalTodos).
//
// NOTE(review): ActiveForm presumably holds the present-continuous phrasing of
// Content that is shown while the item is in progress — confirm with callers.
type Todo struct {
	Content    string     `json:"content"`
	Status     TodoStatus `json:"status"`
	ActiveForm string     `json:"active_form"`
}
 38
 39// HasIncompleteTodos returns true if there are any non-completed todos.
 40func HasIncompleteTodos(todos []Todo) bool {
 41	for _, todo := range todos {
 42		if todo.Status != TodoStatusCompleted {
 43			return true
 44		}
 45	}
 46	return false
 47}
 48
// Session is the in-memory view of a persisted chat session.
//
// Field notes, as established by the service in this file:
//   - ID is a random UUID for sessions made by Create, the tool call ID for
//     CreateTaskSession, and "title-" + the parent ID for CreateTitleSession.
//   - ParentSessionID is empty for top-level sessions created by Create.
//   - MessageCount is read back from the database but never written by Save.
//   - EstimatedUsage is not written to the database by this service; Save
//     copies the caller's value onto the Session it returns.
//   - An empty SummaryMessageID or an empty Todos list is stored as NULL by
//     Save.
//   - NOTE(review): the units of CreatedAt/UpdatedAt are not visible here
//     (likely Unix timestamps) — confirm against the schema.
type Session struct {
	ID               string
	ParentSessionID  string
	Title            string
	MessageCount     int64
	PromptTokens     int64
	CompletionTokens int64
	EstimatedUsage   bool
	SummaryMessageID string
	Cost             float64
	Todos            []Todo
	CreatedAt        int64
	UpdatedAt        int64
}
 63
// Service manages persisted sessions. Through the embedded
// pubsub.Subscriber, callers can observe the Session events the service
// publishes.
type Service interface {
	pubsub.Subscriber[Session]
	// Create inserts a new top-level session with the given title.
	Create(ctx context.Context, title string) (Session, error)
	// CreateTitleSession creates the child session used to generate a title
	// for parentSessionID.
	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
	// CreateTaskSession creates a child session of parentSessionID whose ID is
	// toolCallID.
	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
	// Get returns the session with the given ID.
	Get(ctx context.Context, id string) (Session, error)
	// GetLast returns the session selected by the GetLastSession query.
	GetLast(ctx context.Context) (Session, error)
	// List returns all sessions.
	List(ctx context.Context) ([]Session, error)
	// Save writes the mutable fields of session and returns the stored row.
	Save(ctx context.Context, session Session) (Session, error)
	// UpdateTitleAndUsage writes only the title and usage counters of a session.
	UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
	// Rename changes only the title of a session.
	Rename(ctx context.Context, id string, title string) error
	// Delete removes a session together with its messages and files.
	Delete(ctx context.Context, id string) error

	// Agent tool session management
	// CreateAgentToolSessionID builds an ID of the form "messageID$$toolCallID".
	CreateAgentToolSessionID(messageID, toolCallID string) string
	// ParseAgentToolSessionID splits an ID built by CreateAgentToolSessionID;
	// ok is false when the ID does not have exactly two "$$"-separated parts.
	ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)
	// IsAgentToolSession reports whether sessionID parses as an agent tool ID.
	IsAgentToolSession(sessionID string) bool
}
 82
// service is the database-backed implementation of Service.
//
// The embedded Broker supplies Publish, which is used for Created, Updated
// and Deleted events. Because NewService returns *service as a Service, the
// Broker presumably also provides the pubsub.Subscriber[Session] methods
// (NOTE(review): confirm in the pubsub package). db is used to open
// transactions (see Delete); q runs the generated queries.
type service struct {
	*pubsub.Broker[Session]
	db *sql.DB
	q  *db.Queries
}
 88
 89func (s *service) Create(ctx context.Context, title string) (Session, error) {
 90	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 91		ID:    uuid.New().String(),
 92		Title: title,
 93	})
 94	if err != nil {
 95		return Session{}, err
 96	}
 97	session := s.fromDBItem(dbSession)
 98	s.Publish(pubsub.CreatedEvent, session)
 99	event.SessionCreated()
100	return session, nil
101}
102
// CreateTaskSession creates a child session of parentSessionID. The tool call
// ID is reused as the new session's ID. A CreatedEvent is published, but
// unlike Create this does not call event.SessionCreated().
func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
	parent := sql.NullString{String: parentSessionID, Valid: true}
	row, err := s.q.CreateSession(ctx, db.CreateSessionParams{
		ID:              toolCallID,
		ParentSessionID: parent,
		Title:           title,
	})
	if err != nil {
		return Session{}, err
	}

	task := s.fromDBItem(row)
	s.Publish(pubsub.CreatedEvent, task)
	return task, nil
}
116
// CreateTitleSession creates the helper child session used to generate a
// title for parentSessionID. Its ID is "title-" followed by the parent ID, so
// it is derived deterministically from the parent. A CreatedEvent is
// published; event.SessionCreated() is not called.
func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
	params := db.CreateSessionParams{
		ID:              "title-" + parentSessionID,
		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
		Title:           "Generate a title",
	}
	row, err := s.q.CreateSession(ctx, params)
	if err != nil {
		return Session{}, err
	}

	titleSession := s.fromDBItem(row)
	s.Publish(pubsub.CreatedEvent, titleSession)
	return titleSession, nil
}
130
// Delete removes the session identified by id, together with its messages
// and files, inside a single transaction. The lookup error from
// GetSessionByID is returned unwrapped; failures of the later steps are
// wrapped with context. Only after a successful commit does it publish a
// DeletedEvent carrying the session as it was read before deletion, and then
// call event.SessionDeleted().
func (s *service) Delete(ctx context.Context, id string) error {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("beginning transaction: %w", err)
	}
	// Safe to run unconditionally: after a successful Commit, Rollback only
	// returns sql.ErrTxDone, which is deliberately ignored.
	defer tx.Rollback() //nolint:errcheck

	txQueries := s.q.WithTx(tx)

	existing, err := txQueries.GetSessionByID(ctx, id)
	if err != nil {
		return err
	}

	// Messages and files are removed before the session row itself.
	if err := txQueries.DeleteSessionMessages(ctx, existing.ID); err != nil {
		return fmt.Errorf("deleting session messages: %w", err)
	}
	if err := txQueries.DeleteSessionFiles(ctx, existing.ID); err != nil {
		return fmt.Errorf("deleting session files: %w", err)
	}
	if err := txQueries.DeleteSession(ctx, existing.ID); err != nil {
		return fmt.Errorf("deleting session: %w", err)
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("committing transaction: %w", err)
	}

	s.Publish(pubsub.DeletedEvent, s.fromDBItem(existing))
	event.SessionDeleted()
	return nil
}
162
// Get loads the session with the given ID. Any query error is returned as-is,
// alongside the zero Session.
func (s *service) Get(ctx context.Context, id string) (Session, error) {
	row, err := s.q.GetSessionByID(ctx, id)
	if err != nil {
		return Session{}, err
	}
	found := s.fromDBItem(row)
	return found, nil
}
170
// GetLast returns the session selected by the GetLastSession query. Any query
// error is returned as-is, alongside the zero Session.
func (s *service) GetLast(ctx context.Context) (Session, error) {
	row, err := s.q.GetLastSession(ctx)
	if err != nil {
		return Session{}, err
	}
	latest := s.fromDBItem(row)
	return latest, nil
}
178
// Save writes the mutable fields of session back to the database: title,
// token counts, summary message, cost and todos. It then publishes an
// UpdatedEvent with the stored row.
//
// An empty SummaryMessageID or an empty Todos list is written as NULL.
// EstimatedUsage is not part of the update parameters, so the caller's value
// is copied onto the returned Session.
func (s *service) Save(ctx context.Context, session Session) (Session, error) {
	encodedTodos, err := marshalTodos(session.Todos)
	if err != nil {
		return Session{}, err
	}

	params := db.UpdateSessionParams{
		ID:               session.ID,
		Title:            session.Title,
		PromptTokens:     session.PromptTokens,
		CompletionTokens: session.CompletionTokens,
		SummaryMessageID: sql.NullString{
			String: session.SummaryMessageID,
			Valid:  session.SummaryMessageID != "",
		},
		Cost: session.Cost,
		Todos: sql.NullString{
			String: encodedTodos,
			Valid:  encodedTodos != "",
		},
	}
	row, err := s.q.UpdateSession(ctx, params)
	if err != nil {
		return Session{}, err
	}

	saved := s.fromDBItem(row)
	saved.EstimatedUsage = session.EstimatedUsage
	s.Publish(pubsub.UpdatedEvent, saved)
	return saved, nil
}
209
// UpdateTitleAndUsage writes only the title, token counts and cost of
// sessionID, in a single query. This avoids the read-modify-write pattern of
// Get followed by Save, which could overwrite unrelated fields.
//
// Unlike Save, it does not publish an UpdatedEvent.
func (s *service) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error {
	params := db.UpdateSessionTitleAndUsageParams{
		ID:               sessionID,
		Title:            title,
		PromptTokens:     promptTokens,
		CompletionTokens: completionTokens,
		Cost:             cost,
	}
	return s.q.UpdateSessionTitleAndUsage(ctx, params)
}
221
// Rename changes just the title of a session via the RenameSession query,
// which is intended to leave updated_at and the usage fields untouched. No
// event is published.
func (s *service) Rename(ctx context.Context, id string, title string) error {
	params := db.RenameSessionParams{
		ID:    id,
		Title: title,
	}
	return s.q.RenameSession(ctx, params)
}
230
// List returns every session, in the order the ListSessions query yields
// them. On success the result is always a non-nil slice, even when there are
// no rows; on error it is nil.
func (s *service) List(ctx context.Context) ([]Session, error) {
	rows, err := s.q.ListSessions(ctx)
	if err != nil {
		return nil, err
	}

	out := make([]Session, 0, len(rows))
	for _, row := range rows {
		out = append(out, s.fromDBItem(row))
	}
	return out, nil
}
242
// fromDBItem converts a database row into a Session.
//
// The todo list is decoded from the row's nullable JSON column; a NULL column
// yields an empty list. Decoding is best-effort: on failure the error is
// logged and the session is still returned, with whatever list
// unmarshalTodos hands back for the error case. EstimatedUsage is not stored
// in the row and is therefore left false here.
//
// It uses a pointer receiver so that every method of service has the same
// receiver kind and the struct is not copied on each call.
func (s *service) fromDBItem(item db.Session) Session {
	todos, err := unmarshalTodos(item.Todos.String)
	if err != nil {
		// A corrupt todo column must not make the whole session unreadable.
		slog.Error("Failed to unmarshal todos", "session_id", item.ID, "error", err)
	}
	return Session{
		ID:               item.ID,
		ParentSessionID:  item.ParentSessionID.String,
		Title:            item.Title,
		MessageCount:     item.MessageCount,
		PromptTokens:     item.PromptTokens,
		CompletionTokens: item.CompletionTokens,
		SummaryMessageID: item.SummaryMessageID.String,
		Cost:             item.Cost,
		Todos:            todos,
		CreatedAt:        item.CreatedAt,
		UpdatedAt:        item.UpdatedAt,
	}
}
262
// marshalTodos encodes todos as a JSON array string. A nil or empty list
// encodes to "", which Save turns into a NULL column rather than storing "[]"
// or "null".
func marshalTodos(todos []Todo) (string, error) {
	if len(todos) == 0 {
		return "", nil
	}
	encoded, encErr := json.Marshal(todos)
	if encErr != nil {
		return "", encErr
	}
	return string(encoded), nil
}
273
// unmarshalTodos decodes a JSON array of todos. An empty string (which is how
// a NULL column arrives) yields an empty, non-nil slice. On a decode error it
// returns a fresh empty, non-nil slice together with the error, so that
// partially decoded data never leaks to the caller.
func unmarshalTodos(data string) ([]Todo, error) {
	if len(data) == 0 {
		return []Todo{}, nil
	}

	var decoded []Todo
	err := json.Unmarshal([]byte(data), &decoded)
	if err != nil {
		return []Todo{}, err
	}
	return decoded, nil
}
284
// NewService returns a Service backed by the given generated queries and
// database connection. conn is used to open transactions (see Delete); q runs
// the individual queries. A fresh pubsub broker is created for the Session
// events the service publishes.
func NewService(q *db.Queries, conn *sql.DB) Service {
	return &service{
		Broker: pubsub.NewBroker[Session](),
		q:      q,
		db:     conn,
	}
}
293
294// CreateAgentToolSessionID creates a session ID for agent tool sessions using the format "messageID$$toolCallID"
295func (s *service) CreateAgentToolSessionID(messageID, toolCallID string) string {
296	return fmt.Sprintf("%s$$%s", messageID, toolCallID)
297}
298
299// ParseAgentToolSessionID parses an agent tool session ID into its components
300func (s *service) ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) {
301	parts := strings.Split(sessionID, "$$")
302	if len(parts) != 2 {
303		return "", "", false
304	}
305	return parts[0], parts[1], true
306}
307
308// IsAgentToolSession checks if a session ID follows the agent tool session format
309func (s *service) IsAgentToolSession(sessionID string) bool {
310	_, _, ok := s.ParseAgentToolSessionID(sessionID)
311	return ok
312}