session.go

  1package session
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"encoding/json"
  7	"fmt"
  8	"log/slog"
  9	"strings"
 10
 11	"github.com/charmbracelet/crush/internal/db"
 12	"github.com/charmbracelet/crush/internal/event"
 13	"github.com/charmbracelet/crush/internal/pubsub"
 14	"github.com/google/uuid"
 15	"github.com/zeebo/xxh3"
 16)
 17
 18type TodoStatus string
 19
 20const (
 21	TodoStatusPending    TodoStatus = "pending"
 22	TodoStatusInProgress TodoStatus = "in_progress"
 23	TodoStatusCompleted  TodoStatus = "completed"
 24)
 25
 26// HashID returns the XXH3 hash of a session ID (UUID) as a hex string.
 27func HashID(id string) string {
 28	h := xxh3.New()
 29	h.WriteString(id)
 30	return fmt.Sprintf("%x", h.Sum(nil))
 31}
 32
 33type Todo struct {
 34	Content    string     `json:"content"`
 35	Status     TodoStatus `json:"status"`
 36	ActiveForm string     `json:"active_form"`
 37}
 38
 39// HasIncompleteTodos returns true if there are any non-completed todos.
 40func HasIncompleteTodos(todos []Todo) bool {
 41	for _, todo := range todos {
 42		if todo.Status != TodoStatusCompleted {
 43			return true
 44		}
 45	}
 46	return false
 47}
 48
 49type Session struct {
 50	ID               string
 51	ParentSessionID  string
 52	Title            string
 53	MessageCount     int64
 54	PromptTokens     int64
 55	CompletionTokens int64
 56	SummaryMessageID string
 57	Cost             float64
 58	Todos            []Todo
 59	CreatedAt        int64
 60	UpdatedAt        int64
 61}
 62
 63type Service interface {
 64	pubsub.Subscriber[Session]
 65	Create(ctx context.Context, title string) (Session, error)
 66	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
 67	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
 68	Get(ctx context.Context, id string) (Session, error)
 69	List(ctx context.Context) ([]Session, error)
 70	Save(ctx context.Context, session Session) (Session, error)
 71	UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
 72	Rename(ctx context.Context, id string, title string) error
 73	Delete(ctx context.Context, id string) error
 74
 75	// Agent tool session management
 76	CreateAgentToolSessionID(messageID, toolCallID string) string
 77	ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)
 78	IsAgentToolSession(sessionID string) bool
 79}
 80
 81type service struct {
 82	*pubsub.Broker[Session]
 83	db *sql.DB
 84	q  *db.Queries
 85}
 86
 87func (s *service) Create(ctx context.Context, title string) (Session, error) {
 88	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 89		ID:    uuid.New().String(),
 90		Title: title,
 91	})
 92	if err != nil {
 93		return Session{}, err
 94	}
 95	session := s.fromDBItem(dbSession)
 96	s.Publish(pubsub.CreatedEvent, session)
 97	event.SessionCreated()
 98	return session, nil
 99}
100
101func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
102	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
103		ID:              toolCallID,
104		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
105		Title:           title,
106	})
107	if err != nil {
108		return Session{}, err
109	}
110	session := s.fromDBItem(dbSession)
111	s.Publish(pubsub.CreatedEvent, session)
112	return session, nil
113}
114
115func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
116	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
117		ID:              "title-" + parentSessionID,
118		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
119		Title:           "Generate a title",
120	})
121	if err != nil {
122		return Session{}, err
123	}
124	session := s.fromDBItem(dbSession)
125	s.Publish(pubsub.CreatedEvent, session)
126	return session, nil
127}
128
129func (s *service) Delete(ctx context.Context, id string) error {
130	tx, err := s.db.BeginTx(ctx, nil)
131	if err != nil {
132		return fmt.Errorf("beginning transaction: %w", err)
133	}
134	defer tx.Rollback() //nolint:errcheck
135
136	qtx := s.q.WithTx(tx)
137
138	dbSession, err := qtx.GetSessionByID(ctx, id)
139	if err != nil {
140		return err
141	}
142	if err = qtx.DeleteSessionMessages(ctx, dbSession.ID); err != nil {
143		return fmt.Errorf("deleting session messages: %w", err)
144	}
145	if err = qtx.DeleteSessionFiles(ctx, dbSession.ID); err != nil {
146		return fmt.Errorf("deleting session files: %w", err)
147	}
148	if err = qtx.DeleteSession(ctx, dbSession.ID); err != nil {
149		return fmt.Errorf("deleting session: %w", err)
150	}
151	if err = tx.Commit(); err != nil {
152		return fmt.Errorf("committing transaction: %w", err)
153	}
154
155	session := s.fromDBItem(dbSession)
156	s.Publish(pubsub.DeletedEvent, session)
157	event.SessionDeleted()
158	return nil
159}
160
161func (s *service) Get(ctx context.Context, id string) (Session, error) {
162	dbSession, err := s.q.GetSessionByID(ctx, id)
163	if err != nil {
164		return Session{}, err
165	}
166	return s.fromDBItem(dbSession), nil
167}
168
169func (s *service) Save(ctx context.Context, session Session) (Session, error) {
170	todosJSON, err := marshalTodos(session.Todos)
171	if err != nil {
172		return Session{}, err
173	}
174
175	dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
176		ID:               session.ID,
177		Title:            session.Title,
178		PromptTokens:     session.PromptTokens,
179		CompletionTokens: session.CompletionTokens,
180		SummaryMessageID: sql.NullString{
181			String: session.SummaryMessageID,
182			Valid:  session.SummaryMessageID != "",
183		},
184		Cost: session.Cost,
185		Todos: sql.NullString{
186			String: todosJSON,
187			Valid:  todosJSON != "",
188		},
189	})
190	if err != nil {
191		return Session{}, err
192	}
193	session = s.fromDBItem(dbSession)
194	s.Publish(pubsub.UpdatedEvent, session)
195	return session, nil
196}
197
198// UpdateTitleAndUsage updates only the title and usage fields atomically.
199// This is safer than fetching, modifying, and saving the entire session.
200func (s *service) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error {
201	return s.q.UpdateSessionTitleAndUsage(ctx, db.UpdateSessionTitleAndUsageParams{
202		ID:               sessionID,
203		Title:            title,
204		PromptTokens:     promptTokens,
205		CompletionTokens: completionTokens,
206		Cost:             cost,
207	})
208}
209
210// Rename updates only the title of a session without touching updated_at or
211// usage fields.
212func (s *service) Rename(ctx context.Context, id string, title string) error {
213	return s.q.RenameSession(ctx, db.RenameSessionParams{
214		ID:    id,
215		Title: title,
216	})
217}
218
219func (s *service) List(ctx context.Context) ([]Session, error) {
220	dbSessions, err := s.q.ListSessions(ctx)
221	if err != nil {
222		return nil, err
223	}
224	sessions := make([]Session, len(dbSessions))
225	for i, dbSession := range dbSessions {
226		sessions[i] = s.fromDBItem(dbSession)
227	}
228	return sessions, nil
229}
230
231func (s service) fromDBItem(item db.Session) Session {
232	todos, err := unmarshalTodos(item.Todos.String)
233	if err != nil {
234		slog.Error("Failed to unmarshal todos", "session_id", item.ID, "error", err)
235	}
236	return Session{
237		ID:               item.ID,
238		ParentSessionID:  item.ParentSessionID.String,
239		Title:            item.Title,
240		MessageCount:     item.MessageCount,
241		PromptTokens:     item.PromptTokens,
242		CompletionTokens: item.CompletionTokens,
243		SummaryMessageID: item.SummaryMessageID.String,
244		Cost:             item.Cost,
245		Todos:            todos,
246		CreatedAt:        item.CreatedAt,
247		UpdatedAt:        item.UpdatedAt,
248	}
249}
250
251func marshalTodos(todos []Todo) (string, error) {
252	if len(todos) == 0 {
253		return "", nil
254	}
255	data, err := json.Marshal(todos)
256	if err != nil {
257		return "", err
258	}
259	return string(data), nil
260}
261
262func unmarshalTodos(data string) ([]Todo, error) {
263	if data == "" {
264		return []Todo{}, nil
265	}
266	var todos []Todo
267	if err := json.Unmarshal([]byte(data), &todos); err != nil {
268		return []Todo{}, err
269	}
270	return todos, nil
271}
272
273func NewService(q *db.Queries, conn *sql.DB) Service {
274	broker := pubsub.NewBroker[Session]()
275	return &service{
276		Broker: broker,
277		db:     conn,
278		q:      q,
279	}
280}
281
282// CreateAgentToolSessionID creates a session ID for agent tool sessions using the format "messageID$$toolCallID"
283func (s *service) CreateAgentToolSessionID(messageID, toolCallID string) string {
284	return fmt.Sprintf("%s$$%s", messageID, toolCallID)
285}
286
287// ParseAgentToolSessionID parses an agent tool session ID into its components
288func (s *service) ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) {
289	parts := strings.Split(sessionID, "$$")
290	if len(parts) != 2 {
291		return "", "", false
292	}
293	return parts[0], parts[1], true
294}
295
296// IsAgentToolSession checks if a session ID follows the agent tool session format
297func (s *service) IsAgentToolSession(sessionID string) bool {
298	_, _, ok := s.ParseAgentToolSessionID(sessionID)
299	return ok
300}