session.go

  1package session
  2
import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"log/slog"
	"maps"
	"strings"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/db"
	"github.com/charmbracelet/crush/internal/event"
	"github.com/charmbracelet/crush/internal/pubsub"
	"github.com/google/uuid"
	"github.com/zeebo/xxh3"
)
 18
 19type TodoStatus string
 20
 21const (
 22	TodoStatusPending    TodoStatus = "pending"
 23	TodoStatusInProgress TodoStatus = "in_progress"
 24	TodoStatusCompleted  TodoStatus = "completed"
 25)
 26
 27// HashID returns the XXH3 hash of a session ID (UUID) as a hex string.
 28func HashID(id string) string {
 29	h := xxh3.New()
 30	h.WriteString(id)
 31	return fmt.Sprintf("%x", h.Sum(nil))
 32}
 33
 34type Todo struct {
 35	Content    string     `json:"content"`
 36	Status     TodoStatus `json:"status"`
 37	ActiveForm string     `json:"active_form"`
 38}
 39
 40// HasIncompleteTodos returns true if there are any non-completed todos.
 41func HasIncompleteTodos(todos []Todo) bool {
 42	for _, todo := range todos {
 43		if todo.Status != TodoStatusCompleted {
 44			return true
 45		}
 46	}
 47	return false
 48}
 49
 50type Session struct {
 51	ID               string
 52	ParentSessionID  string
 53	Title            string
 54	MessageCount     int64
 55	PromptTokens     int64
 56	CompletionTokens int64
 57	SummaryMessageID string
 58	Cost             float64
 59	Todos            []Todo
 60	Models           map[config.SelectedModelType]config.SelectedModel
 61	CreatedAt        int64
 62	UpdatedAt        int64
 63}
 64
 65type Service interface {
 66	pubsub.Subscriber[Session]
 67	Create(ctx context.Context, title string) (Session, error)
 68	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
 69	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
 70	Get(ctx context.Context, id string) (Session, error)
 71	GetLast(ctx context.Context) (Session, error)
 72	List(ctx context.Context) ([]Session, error)
 73	Save(ctx context.Context, session Session) (Session, error)
 74	SaveWithModels(ctx context.Context, session Session, models map[config.SelectedModelType]config.SelectedModel) (Session, error)
 75	UpdateSessionModels(ctx context.Context, id string, models map[config.SelectedModelType]config.SelectedModel) error
 76	UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
 77	Rename(ctx context.Context, id string, title string) error
 78	Delete(ctx context.Context, id string) error
 79
 80	// Agent tool session management
 81	CreateAgentToolSessionID(messageID, toolCallID string) string
 82	ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)
 83	IsAgentToolSession(sessionID string) bool
 84}
 85
 86type service struct {
 87	*pubsub.Broker[Session]
 88	db *sql.DB
 89	q  *db.Queries
 90}
 91
 92func (s *service) Create(ctx context.Context, title string) (Session, error) {
 93	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 94		ID:    uuid.New().String(),
 95		Title: title,
 96	})
 97	if err != nil {
 98		return Session{}, err
 99	}
100	session := s.fromDBItem(dbSession)
101	s.Publish(pubsub.CreatedEvent, session)
102	event.SessionCreated()
103	return session, nil
104}
105
106func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
107	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
108		ID:              toolCallID,
109		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
110		Title:           title,
111	})
112	if err != nil {
113		return Session{}, err
114	}
115	session := s.fromDBItem(dbSession)
116	s.Publish(pubsub.CreatedEvent, session)
117	return session, nil
118}
119
120func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
121	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
122		ID:              "title-" + parentSessionID,
123		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
124		Title:           "Generate a title",
125	})
126	if err != nil {
127		return Session{}, err
128	}
129	session := s.fromDBItem(dbSession)
130	s.Publish(pubsub.CreatedEvent, session)
131	return session, nil
132}
133
134func (s *service) Delete(ctx context.Context, id string) error {
135	tx, err := s.db.BeginTx(ctx, nil)
136	if err != nil {
137		return fmt.Errorf("beginning transaction: %w", err)
138	}
139	defer tx.Rollback() //nolint:errcheck
140
141	qtx := s.q.WithTx(tx)
142
143	dbSession, err := qtx.GetSessionByID(ctx, id)
144	if err != nil {
145		return err
146	}
147	if err = qtx.DeleteSessionMessages(ctx, dbSession.ID); err != nil {
148		return fmt.Errorf("deleting session messages: %w", err)
149	}
150	if err = qtx.DeleteSessionFiles(ctx, dbSession.ID); err != nil {
151		return fmt.Errorf("deleting session files: %w", err)
152	}
153	if err = qtx.DeleteSession(ctx, dbSession.ID); err != nil {
154		return fmt.Errorf("deleting session: %w", err)
155	}
156	if err = tx.Commit(); err != nil {
157		return fmt.Errorf("committing transaction: %w", err)
158	}
159
160	session := s.fromDBItem(dbSession)
161	s.Publish(pubsub.DeletedEvent, session)
162	event.SessionDeleted()
163	return nil
164}
165
166func (s *service) Get(ctx context.Context, id string) (Session, error) {
167	dbSession, err := s.q.GetSessionByID(ctx, id)
168	if err != nil {
169		return Session{}, err
170	}
171	return s.fromDBItem(dbSession), nil
172}
173
174func (s *service) GetLast(ctx context.Context) (Session, error) {
175	dbSession, err := s.q.GetLastSession(ctx)
176	if err != nil {
177		return Session{}, err
178	}
179	return s.fromDBItem(dbSession), nil
180}
181
182func (s *service) Save(ctx context.Context, session Session) (Session, error) {
183	todosJSON, err := marshalTodos(session.Todos)
184	if err != nil {
185		return Session{}, err
186	}
187
188	dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
189		ID:               session.ID,
190		Title:            session.Title,
191		PromptTokens:     session.PromptTokens,
192		CompletionTokens: session.CompletionTokens,
193		SummaryMessageID: sql.NullString{
194			String: session.SummaryMessageID,
195			Valid:  session.SummaryMessageID != "",
196		},
197		Cost: session.Cost,
198		Todos: sql.NullString{
199			String: todosJSON,
200			Valid:  todosJSON != "",
201		},
202	})
203	if err != nil {
204		return Session{}, err
205	}
206	session = s.fromDBItem(dbSession)
207	s.Publish(pubsub.UpdatedEvent, session)
208	return session, nil
209}
210
211// UpdateTitleAndUsage updates only the title and usage fields atomically.
212// This is safer than fetching, modifying, and saving the entire session.
213func (s *service) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error {
214	return s.q.UpdateSessionTitleAndUsage(ctx, db.UpdateSessionTitleAndUsageParams{
215		ID:               sessionID,
216		Title:            title,
217		PromptTokens:     promptTokens,
218		CompletionTokens: completionTokens,
219		Cost:             cost,
220	})
221}
222
223// Rename updates only the title of a session without touching updated_at or
224// usage fields.
225func (s *service) Rename(ctx context.Context, id string, title string) error {
226	return s.q.RenameSession(ctx, db.RenameSessionParams{
227		ID:    id,
228		Title: title,
229	})
230}
231
232func (s *service) List(ctx context.Context) ([]Session, error) {
233	dbSessions, err := s.q.ListSessions(ctx)
234	if err != nil {
235		return nil, err
236	}
237	sessions := make([]Session, len(dbSessions))
238	for i, dbSession := range dbSessions {
239		sessions[i] = s.fromDBItem(dbSession)
240	}
241	return sessions, nil
242}
243
244func (s service) fromDBItem(item db.Session) Session {
245	todos, err := unmarshalTodos(item.Todos.String)
246	if err != nil {
247		slog.Error("Failed to unmarshal todos", "session_id", item.ID, "error", err)
248	}
249	models, err := unmarshalModels(item.Models.String)
250	if err != nil {
251		slog.Error("Failed to unmarshal models", "session_id", item.ID, "error", err)
252	}
253	return Session{
254		ID:               item.ID,
255		ParentSessionID:  item.ParentSessionID.String,
256		Title:            item.Title,
257		MessageCount:     item.MessageCount,
258		PromptTokens:     item.PromptTokens,
259		CompletionTokens: item.CompletionTokens,
260		SummaryMessageID: item.SummaryMessageID.String,
261		Cost:             item.Cost,
262		Todos:            todos,
263		Models:           models,
264		CreatedAt:        item.CreatedAt,
265		UpdatedAt:        item.UpdatedAt,
266	}
267}
268
269func marshalTodos(todos []Todo) (string, error) {
270	if len(todos) == 0 {
271		return "", nil
272	}
273	data, err := json.Marshal(todos)
274	if err != nil {
275		return "", err
276	}
277	return string(data), nil
278}
279
280func unmarshalTodos(data string) ([]Todo, error) {
281	if data == "" {
282		return []Todo{}, nil
283	}
284	var todos []Todo
285	if err := json.Unmarshal([]byte(data), &todos); err != nil {
286		return []Todo{}, err
287	}
288	return todos, nil
289}
290
291func marshalModels(models map[config.SelectedModelType]config.SelectedModel) (string, error) {
292	if len(models) == 0 {
293		return "", nil
294	}
295	data, err := json.Marshal(models)
296	if err != nil {
297		return "", err
298	}
299	return string(data), nil
300}
301
302func unmarshalModels(data string) (map[config.SelectedModelType]config.SelectedModel, error) {
303	if data == "" {
304		return nil, nil
305	}
306	var models map[config.SelectedModelType]config.SelectedModel
307	if err := json.Unmarshal([]byte(data), &models); err != nil {
308		return nil, err
309	}
310	return models, nil
311}
312
313func (s *service) UpdateSessionModels(ctx context.Context, id string, models map[config.SelectedModelType]config.SelectedModel) error {
314	modelsJSON, err := marshalModels(models)
315	if err != nil {
316		return fmt.Errorf("failed to marshal models: %w", err)
317	}
318	_, err = s.q.UpdateSessionModels(ctx, db.UpdateSessionModelsParams{
319		Models: sql.NullString{String: modelsJSON, Valid: modelsJSON != ""},
320		ID:     id,
321	})
322	return err
323}
324
325// SaveWithModels saves the session and then persists the models column as a
326// second operation. This is intentionally non-atomic: if the models update
327// fails, the session fields are still saved (which is equivalent to the
328// pre-feature behavior where models were never persisted). The next agent turn
329// will retry the models write, so transient failures are self-healing.
330func (s *service) SaveWithModels(ctx context.Context, session Session, models map[config.SelectedModelType]config.SelectedModel) (Session, error) {
331	saved, err := s.Save(ctx, session)
332	if err != nil {
333		return Session{}, err
334	}
335	if err := s.UpdateSessionModels(ctx, session.ID, models); err != nil {
336		return Session{}, fmt.Errorf("failed to persist models: %w", err)
337	}
338	return saved, nil
339}
340
341func NewService(q *db.Queries, conn *sql.DB) Service {
342	broker := pubsub.NewBroker[Session]()
343	return &service{
344		Broker: broker,
345		db:     conn,
346		q:      q,
347	}
348}
349
350// CreateAgentToolSessionID creates a session ID for agent tool sessions using the format "messageID$$toolCallID"
351func (s *service) CreateAgentToolSessionID(messageID, toolCallID string) string {
352	return fmt.Sprintf("%s$$%s", messageID, toolCallID)
353}
354
355// ParseAgentToolSessionID parses an agent tool session ID into its components
356func (s *service) ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) {
357	parts := strings.Split(sessionID, "$$")
358	if len(parts) != 2 {
359		return "", "", false
360	}
361	return parts[0], parts[1], true
362}
363
364// IsAgentToolSession checks if a session ID follows the agent tool session format
365func (s *service) IsAgentToolSession(sessionID string) bool {
366	_, _, ok := s.ParseAgentToolSessionID(sessionID)
367	return ok
368}