session.go

  1package session
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"encoding/json"
  7	"fmt"
  8	"log/slog"
  9	"strings"
 10
 11	"github.com/charmbracelet/crush/internal/db"
 12	"github.com/charmbracelet/crush/internal/event"
 13	"github.com/charmbracelet/crush/internal/pubsub"
 14	"github.com/google/uuid"
 15	"github.com/zeebo/xxh3"
 16)
 17
 18type TodoStatus string
 19
 20const (
 21	TodoStatusPending    TodoStatus = "pending"
 22	TodoStatusInProgress TodoStatus = "in_progress"
 23	TodoStatusCompleted  TodoStatus = "completed"
 24)
 25
 26// HashID returns the XXH3 hash of a session ID (UUID) as a hex string.
 27func HashID(id string) string {
 28	h := xxh3.New()
 29	h.WriteString(id)
 30	return fmt.Sprintf("%x", h.Sum(nil))
 31}
 32
 33type Todo struct {
 34	Content    string     `json:"content"`
 35	Status     TodoStatus `json:"status"`
 36	ActiveForm string     `json:"active_form"`
 37}
 38
 39// HasIncompleteTodos returns true if there are any non-completed todos.
 40func HasIncompleteTodos(todos []Todo) bool {
 41	for _, todo := range todos {
 42		if todo.Status != TodoStatusCompleted {
 43			return true
 44		}
 45	}
 46	return false
 47}
 48
 49type Session struct {
 50	ID               string
 51	ParentSessionID  string
 52	Title            string
 53	MessageCount     int64
 54	PromptTokens     int64
 55	CompletionTokens int64
 56	SummaryMessageID string
 57	Cost             float64
 58	Todos            []Todo
 59	CreatedAt        int64
 60	UpdatedAt        int64
 61}
 62
 63type Service interface {
 64	pubsub.Subscriber[Session]
 65	Create(ctx context.Context, title string) (Session, error)
 66	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
 67	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
 68	Get(ctx context.Context, id string) (Session, error)
 69	GetLast(ctx context.Context) (Session, error)
 70	List(ctx context.Context) ([]Session, error)
 71	Save(ctx context.Context, session Session) (Session, error)
 72	UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
 73	Rename(ctx context.Context, id string, title string) error
 74	Delete(ctx context.Context, id string) error
 75
 76	// Agent tool session management
 77	CreateAgentToolSessionID(messageID, toolCallID string) string
 78	ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)
 79	IsAgentToolSession(sessionID string) bool
 80}
 81
 82type service struct {
 83	*pubsub.Broker[Session]
 84	db *sql.DB
 85	q  *db.Queries
 86}
 87
 88func (s *service) Create(ctx context.Context, title string) (Session, error) {
 89	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 90		ID:    uuid.New().String(),
 91		Title: title,
 92	})
 93	if err != nil {
 94		return Session{}, err
 95	}
 96	session := s.fromDBItem(dbSession)
 97	s.Publish(pubsub.CreatedEvent, session)
 98	event.SessionCreated()
 99	return session, nil
100}
101
// CreateTaskSession inserts a child session of parentSessionID whose ID is
// the given tool call ID, then publishes a CreatedEvent. Unlike Create it
// does not call event.SessionCreated.
func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
	parent := sql.NullString{String: parentSessionID, Valid: true}
	row, err := s.q.CreateSession(ctx, db.CreateSessionParams{
		ID:              toolCallID,
		ParentSessionID: parent,
		Title:           title,
	})
	if err != nil {
		return Session{}, err
	}
	created := s.fromDBItem(row)
	s.Publish(pubsub.CreatedEvent, created)
	return created, nil
}
115
// CreateTitleSession inserts the helper session used for title generation.
// Its ID is "title-" + parentSessionID and it is linked to that parent.
// A CreatedEvent is published on success.
func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
	params := db.CreateSessionParams{
		ID:              "title-" + parentSessionID,
		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
		Title:           "Generate a title",
	}
	row, err := s.q.CreateSession(ctx, params)
	if err != nil {
		return Session{}, err
	}
	created := s.fromDBItem(row)
	s.Publish(pubsub.CreatedEvent, created)
	return created, nil
}
129
// Delete removes a session together with its messages and files inside a
// single transaction. Subscribers receive a DeletedEvent carrying the session
// as it was loaded before deletion, and event.SessionDeleted is called, only
// after the transaction has committed.
//
// Every failure is wrapped with context using %w, so callers can still match
// the underlying error with errors.Is (for example a not-found error from the
// lookup, if the query layer reports one).
func (s *service) Delete(ctx context.Context, id string) error {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("beginning transaction: %w", err)
	}
	// After a successful Commit, Rollback just returns sql.ErrTxDone (ignored);
	// the deferred call matters only on the early-return error paths.
	defer tx.Rollback() //nolint:errcheck

	qtx := s.q.WithTx(tx)

	dbSession, err := qtx.GetSessionByID(ctx, id)
	if err != nil {
		return fmt.Errorf("getting session %q: %w", id, err)
	}
	if err = qtx.DeleteSessionMessages(ctx, dbSession.ID); err != nil {
		return fmt.Errorf("deleting session messages: %w", err)
	}
	if err = qtx.DeleteSessionFiles(ctx, dbSession.ID); err != nil {
		return fmt.Errorf("deleting session files: %w", err)
	}
	if err = qtx.DeleteSession(ctx, dbSession.ID); err != nil {
		return fmt.Errorf("deleting session: %w", err)
	}
	if err = tx.Commit(); err != nil {
		return fmt.Errorf("committing transaction: %w", err)
	}

	// Publish only after the commit so subscribers never observe a deletion
	// that was rolled back.
	session := s.fromDBItem(dbSession)
	s.Publish(pubsub.DeletedEvent, session)
	event.SessionDeleted()
	return nil
}
161
// Get loads the session with the given ID. Errors from the query layer are
// returned unchanged, with a zero Session.
func (s *service) Get(ctx context.Context, id string) (Session, error) {
	row, err := s.q.GetSessionByID(ctx, id)
	if err != nil {
		return Session{}, err
	}
	found := s.fromDBItem(row)
	return found, nil
}
169
// GetLast returns the session selected by the GetLastSession query. Errors
// from the query layer are returned unchanged, with a zero Session.
func (s *service) GetLast(ctx context.Context) (Session, error) {
	row, err := s.q.GetLastSession(ctx)
	if err != nil {
		return Session{}, err
	}
	latest := s.fromDBItem(row)
	return latest, nil
}
177
// Save persists the title, token counts, summary message ID, cost and todos
// of session. It then publishes an UpdatedEvent built from the row returned
// by the database and returns that same value. An empty SummaryMessageID
// and an empty todo list are written as NULL. MessageCount is not written.
func (s *service) Save(ctx context.Context, session Session) (Session, error) {
	encodedTodos, err := marshalTodos(session.Todos)
	if err != nil {
		return Session{}, err
	}

	summaryID := sql.NullString{String: session.SummaryMessageID, Valid: session.SummaryMessageID != ""}
	todos := sql.NullString{String: encodedTodos, Valid: encodedTodos != ""}

	row, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
		ID:               session.ID,
		Title:            session.Title,
		PromptTokens:     session.PromptTokens,
		CompletionTokens: session.CompletionTokens,
		SummaryMessageID: summaryID,
		Cost:             session.Cost,
		Todos:            todos,
	})
	if err != nil {
		return Session{}, err
	}
	updated := s.fromDBItem(row)
	s.Publish(pubsub.UpdatedEvent, updated)
	return updated, nil
}
206
// UpdateTitleAndUsage overwrites only the title, token counts and cost of one
// session with a single query. This is safer than loading the whole session,
// modifying it and calling Save. Unlike Save, it publishes no event.
func (s *service) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error {
	params := db.UpdateSessionTitleAndUsageParams{
		ID:               sessionID,
		Title:            title,
		PromptTokens:     promptTokens,
		CompletionTokens: completionTokens,
		Cost:             cost,
	}
	return s.q.UpdateSessionTitleAndUsage(ctx, params)
}
218
// Rename changes only a session's title. It is intended to leave updated_at
// and the usage fields untouched; that is up to the RenameSession query.
// No event is published.
func (s *service) Rename(ctx context.Context, id string, title string) error {
	params := db.RenameSessionParams{ID: id, Title: title}
	return s.q.RenameSession(ctx, params)
}
227
// List returns every row produced by the ListSessions query, converted to
// Session values in the same order. An empty result yields an empty,
// non-nil slice.
func (s *service) List(ctx context.Context) ([]Session, error) {
	rows, err := s.q.ListSessions(ctx)
	if err != nil {
		return nil, err
	}
	out := make([]Session, 0, len(rows))
	for _, row := range rows {
		out = append(out, s.fromDBItem(row))
	}
	return out, nil
}
239
// fromDBItem converts a db.Session row into a Session. Nullable string
// columns become "" when NULL.
//
// The receiver is a pointer, matching every other service method, so the
// struct is not copied on each call.
func (s *service) fromDBItem(item db.Session) Session {
	todos, err := unmarshalTodos(item.Todos.String)
	if err != nil {
		// Best effort: a corrupt todos value should not make the whole
		// session unreadable. Log it and keep the empty list that
		// unmarshalTodos returns on failure.
		slog.Error("Failed to unmarshal todos", "session_id", item.ID, "error", err)
	}
	return Session{
		ID:               item.ID,
		ParentSessionID:  item.ParentSessionID.String,
		Title:            item.Title,
		MessageCount:     item.MessageCount,
		PromptTokens:     item.PromptTokens,
		CompletionTokens: item.CompletionTokens,
		SummaryMessageID: item.SummaryMessageID.String,
		Cost:             item.Cost,
		Todos:            todos,
		CreatedAt:        item.CreatedAt,
		UpdatedAt:        item.UpdatedAt,
	}
}
259
// marshalTodos encodes todos as a JSON array. A nil or empty list encodes to
// "" so that Save can store it as NULL.
func marshalTodos(todos []Todo) (string, error) {
	if len(todos) == 0 {
		return "", nil
	}
	raw, err := json.Marshal(todos)
	if err != nil {
		return "", err
	}
	return string(raw), nil
}
270
// unmarshalTodos decodes a JSON todo list.
//
// An empty string yields an empty, non-nil slice. A decode failure also
// yields an empty, non-nil slice, together with the error. Otherwise the
// decoded slice is returned as-is; note that a JSON null decodes to nil.
func unmarshalTodos(data string) ([]Todo, error) {
	if len(data) == 0 {
		return []Todo{}, nil
	}
	var decoded []Todo
	err := json.Unmarshal([]byte(data), &decoded)
	if err != nil {
		return []Todo{}, err
	}
	return decoded, nil
}
281
// NewService returns the SQL-backed Service. q runs the generated queries,
// conn is used to open transactions, and a new broker is created to publish
// Session events.
func NewService(q *db.Queries, conn *sql.DB) Service {
	return &service{
		Broker: pubsub.NewBroker[Session](),
		db:     conn,
		q:      q,
	}
}
290
291// CreateAgentToolSessionID creates a session ID for agent tool sessions using the format "messageID$$toolCallID"
292func (s *service) CreateAgentToolSessionID(messageID, toolCallID string) string {
293	return fmt.Sprintf("%s$$%s", messageID, toolCallID)
294}
295
296// ParseAgentToolSessionID parses an agent tool session ID into its components
297func (s *service) ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) {
298	parts := strings.Split(sessionID, "$$")
299	if len(parts) != 2 {
300		return "", "", false
301	}
302	return parts[0], parts[1], true
303}
304
305// IsAgentToolSession checks if a session ID follows the agent tool session format
306func (s *service) IsAgentToolSession(sessionID string) bool {
307	_, _, ok := s.ParseAgentToolSessionID(sessionID)
308	return ok
309}